+++ /dev/null
-include(CheckCSourceRuns)
-
-set(AVX_CODE "
- #include <immintrin.h>
- int main()
- {
- __m256 a;
- a = _mm256_set1_ps(0);
- return 0;
- }
-")
-
-set(AVX512_CODE "
- #include <immintrin.h>
- int main()
- {
- __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0);
- __m512i b = a;
- __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
- return 0;
- }
-")
-
-set(AVX2_CODE "
- #include <immintrin.h>
- int main()
- {
- __m256i a = {0};
- a = _mm256_abs_epi16(a);
- __m256i x;
- _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
- return 0;
- }
-")
-
-set(FMA_CODE "
- #include <immintrin.h>
- int main()
- {
- __m256 acc = _mm256_setzero_ps();
- const __m256 d = _mm256_setzero_ps();
- const __m256 p = _mm256_setzero_ps();
- acc = _mm256_fmadd_ps( d, p, acc );
- return 0;
- }
-")
-
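-# check_sse(<type> <flags>): for each candidate flag, try to compile and run
-# ${<type>_CODE} with that flag; the first flag that works is cached in
-# <type>_FLAGS and <type>_FOUND is set to TRUE (FALSE if none work).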
-macro(check_sse type flags)
- set(__FLAG_I 1)
- set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
- foreach (__FLAG ${flags})
- if (NOT ${type}_FOUND)
- set(CMAKE_REQUIRED_FLAGS ${__FLAG})
- check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
- if (HAS_${type}_${__FLAG_I})
- set(${type}_FOUND TRUE CACHE BOOL "${type} support")
- set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
- endif()
- math(EXPR __FLAG_I "${__FLAG_I}+1")
- endif()
- endforeach()
- set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
-
- if (NOT ${type}_FOUND)
- set(${type}_FOUND FALSE CACHE BOOL "${type} support")
- set(${type}_FLAGS "" CACHE STRING "${type} flags")
- endif()
-
- mark_as_advanced(${type}_FOUND ${type}_FLAGS)
-endmacro()
-
-# flags are for MSVC only!
-check_sse("AVX" " ;/arch:AVX")
-if (NOT ${AVX_FOUND})
- set(GGML_AVX OFF)
-else()
- set(GGML_AVX ON)
-endif()
-
-check_sse("AVX2" " ;/arch:AVX2")
-check_sse("FMA" " ;/arch:AVX2")
-if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
- set(GGML_AVX2 OFF)
-else()
- set(GGML_AVX2 ON)
-endif()
-
-check_sse("AVX512" " ;/arch:AVX512")
-if (NOT ${AVX512_FOUND})
- set(GGML_AVX512 OFF)
-else()
- set(GGML_AVX512 ON)
-endif()
+++ /dev/null
-#include "ggml-amx.h"
-#include "ggml-amx/common.h"
-#include "ggml-amx/mmq.h"
-#include "ggml-backend-impl.h"
-#include "ggml-impl.h"
-
-#if defined(__gnu_linux__)
-#include <sys/syscall.h>
-#include <unistd.h>
-#endif
-
-#include <cstdlib>
-#include <cstring>
-#include <memory>
-
-#if defined(__AMX_INT8__)
-
-// AMX buffer interface
-static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
-}
-
-static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)(buffer->context);
-}
-
-static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- memset((char *)tensor->data + offset, value, size);
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- if (qtype_has_amx_kernels(tensor->type)) {
- ggml_backend_amx_convert_weight(tensor, data, offset, size);
- } else {
- memcpy((char *)tensor->data + offset, data, size);
- }
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
- memcpy(data, (const char *)tensor->data + offset, size);
-
- GGML_UNUSED(buffer);
-}
-
-static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
- if (ggml_backend_buffer_is_host(src->buffer)) {
- if (qtype_has_amx_kernels(src->type)) {
- ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
- } else {
- memcpy(dst->data, src->data, ggml_nbytes(src));
- }
- return true;
- }
- return false;
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- memset(buffer->context, value, buffer->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
- /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
- /* .get_base = */ ggml_backend_amx_buffer_get_base,
- /* .init_tensor = */ NULL, // no initialization required
- /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_amx_buffer_clear,
- /* .reset = */ NULL,
-};
-
-static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "AMX";
-
- GGML_UNUSED(buft);
-}
-
-static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
- if (data == NULL) {
- fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
- return NULL;
- }
-
- return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
-}
-
-static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return TENSOR_ALIGNMENT;
-
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
- return ggml_backend_amx_get_alloc_size(tensor);
-
- GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return false;
-
- GGML_UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
- static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_amx_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
- /* .is_host = */ ggml_backend_amx_buffer_type_is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
- /* .context = */ NULL,
- };
-
- return &ggml_backend_buffer_type_amx;
-}
-
-// backend interface
-
-static const char * ggml_backend_amx_name(ggml_backend_t backend) {
- return "AMX";
-
- GGML_UNUSED(backend);
-}
-
-static void ggml_backend_amx_free(ggml_backend_t backend) {
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
- delete ctx;
- delete backend;
-}
-
-static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
-
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
-
- switch (node->op) {
- case GGML_OP_MUL_MAT:
- ggml_backend_amx_mul_mat(ctx, node);
- break;
-
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- break;
-
- default:
- fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
- GGML_ASSERT(false);
- }
- }
-
- return GGML_STATUS_SUCCESS;
-
- GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i ggml_backend_amx_i = {
- /* .get_name = */ ggml_backend_amx_name,
- /* .free = */ ggml_backend_amx_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_amx_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_amx_guid() {
- static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
- return &guid;
-}
-
-#define ARCH_GET_XCOMP_PERM 0x1022
-#define ARCH_REQ_XCOMP_PERM 0x1023
-#define XFEATURE_XTILECFG 17
-#define XFEATURE_XTILEDATA 18
-
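-// On Linux, permission to use the AMX tile data state must be requested from the
-// kernel via arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) before executing
-// any AMX instruction.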
-static bool ggml_amx_init() {
-#if defined(__gnu_linux__)
- if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
- fprintf(stderr, "AMX is not ready to be used!\n");
- return false;
- }
- return true;
-#elif defined(_WIN32)
- return true;
-#endif
-}
-
-ggml_backend_t ggml_backend_amx_init() {
-
- // invoke a Linux system call to request access to AMX features
- ggml_amx_init();
-
- // backend context
- ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
-
- // ggml amx backend
- ggml_backend_t backend = new ggml_backend {
- /* .guid = */ ggml_backend_amx_guid(),
- /* .interface = */ ggml_backend_amx_i,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
- /* .context = */ ctx,
- };
-
- return backend;
-}
-
-bool ggml_backend_is_amx(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
-}
-
-void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
- GGML_ASSERT(ggml_backend_is_amx(backend_amx));
-
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
- ctx->n_threads = n_threads;
-}
-
-// device interface
-
-static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
- return "AMX";
-
- GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
- return "Intel Advanced Matrix Extensions";
-
- GGML_UNUSED(dev);
-}
-
-static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- // TODO
- *free = 0;
- *total = 0;
-
- GGML_UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
- return GGML_BACKEND_DEVICE_TYPE_ACCEL;
-
- GGML_UNUSED(dev);
-}
-
-static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_amx_device_get_name(dev);
- props->description = ggml_backend_amx_device_get_description(dev);
- props->type = ggml_backend_amx_device_get_type(dev);
- ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
-    // `buffer_from_host_ptr` is intended to be used with mmap, when the memory layout is unchanged
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
- return ggml_backend_amx_init();
-
- GGML_UNUSED(dev);
- GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
- return ggml_backend_amx_buffer_type();
-
- GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-
- // handle only 2d gemm for now
- auto is_contiguous_2d = [](const struct ggml_tensor * t) {
- return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
- };
-
- switch (op->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- return true;
-
- case GGML_OP_MUL_MAT: {
- const struct ggml_tensor * src0 = op->src[0];
- const struct ggml_tensor * src1 = op->src[1];
-
- const enum ggml_type type = src0->type;
- const int64_t ne0 = op->ne[0];
-
- bool is_training = src0->grad || src1->grad;
-
-            // amx kernels are enabled for Q4_0, Q4_1, Q8_0, F16
- // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
- bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
-
- bool can_use_amx =
- is_contiguous_2d(src0) && // src0 must be contiguous
- is_contiguous_2d(src1) && // src1 must be contiguous
- !is_training && // inference only
- src1->type == GGML_TYPE_F32 && // src1 must be float32
- has_amx_kernels && // with amx kernel impls
-                ne0 % (TILE_N * 2) == 0;        // out_features must be a multiple of TILE_N * 2 (= 32)
-
- return can_use_amx;
- }
- default:
- return false;
- }
-
- GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
-
- GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
- /* .get_name = */ ggml_backend_amx_device_get_name,
- /* .get_description = */ ggml_backend_amx_device_get_description,
- /* .get_memory = */ ggml_backend_amx_device_get_memory,
- /* .get_type = */ ggml_backend_amx_device_get_type,
- /* .get_props = */ ggml_backend_amx_device_get_props,
- /* .init_backend = */ ggml_backend_amx_device_init,
- /* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_amx_device_supports_op,
- /* .supports_buft = */ ggml_backend_amx_device_supports_buft,
- /* .offload_op = */ NULL,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
- return "AMX";
-
- GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
- return 1;
-
- GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- GGML_ASSERT(index == 0);
-
- static ggml_backend_device ggml_backend_amx_device = {
- /* .iface = */ ggml_backend_amx_device_i,
- /* .reg = */ reg,
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_amx_device;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(index);
-}
-
-static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
- return (void *)ggml_backend_amx_set_n_threads;
- }
- return NULL;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(name);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
- /* .get_name = */ ggml_backend_amx_reg_get_name,
- /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
- /* .get_device = */ ggml_backend_amx_reg_get_device,
- /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_amx_reg(void) {
- static struct ggml_backend_reg ggml_backend_amx_reg = {
- /* .iface = */ ggml_backend_amx_reg_i,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_amx_reg;
-}
-
-#else // if defined(__AMX_INT8__)
-
-ggml_backend_t ggml_backend_amx_init(void) {
- fprintf(stderr, "GGML is not compiled with AMX support!\n");
- return ggml_backend_t{};
-}
-
-void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
- fprintf(stderr, "GGML is not compiled with AMX support!\n");
-
- GGML_UNUSED(backend_amx);
- GGML_UNUSED(n_threads);
-}
-
-#endif
+++ /dev/null
-#include "ggml-impl.h"
-#include "ggml-blas.h"
-#include "ggml-backend-impl.h"
-
-#include <future>
-#include <vector>
-#include <cstring>
-
-#if defined(GGML_USE_ACCELERATE)
-# include <Accelerate/Accelerate.h>
-#elif defined(GGML_BLAS_USE_MKL)
-# include <mkl.h>
-#elif defined(GGML_BLAS_USE_BLIS)
-# include <blis.h>
-#elif defined(GGML_BLAS_USE_NVPL)
-# include <nvpl_blas.h>
-#else
-# include <cblas.h>
-#endif
-
-struct ggml_backend_blas_context {
- int n_threads = GGML_DEFAULT_N_THREADS;
- std::unique_ptr<char[]> work_data;
- size_t work_size = 0;
-#ifndef GGML_USE_OPENMP
- std::vector<std::future<void>> tasks;
-#endif
-};
-
-static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const enum ggml_type type = src0->type;
-
- GGML_ASSERT(ne0 == ne01);
- GGML_ASSERT(ne1 == ne11);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == ggml_type_size(type));
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
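-    // work buffer must be large enough to hold src0 converted to F32
-    // (no conversion buffer is needed when src0 is already F32)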
- const int64_t ne_plane = ne01*ne00;
- const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
-
- if (ctx->work_size < desired_wsize) {
- ctx->work_data.reset(new char[desired_wsize]);
- ctx->work_size = desired_wsize;
- }
- void * wdata = ctx->work_data.get();
-
- // convert src0 to float
- if (type != GGML_TYPE_F32) {
- const auto * type_traits = ggml_get_type_traits(type);
- ggml_to_float_t const to_float = type_traits->to_float;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
- float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
-
- const int min_cols_per_thread = 4096;
- const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
- const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
-
-#ifdef GGML_USE_OPENMP
- #pragma omp parallel for num_threads(n_threads)
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
- }
-#else
- for (int i = 1; i < n_threads; i++) {
- const int64_t start = i*ne01/n_threads;
- const int64_t end = (i + 1)*ne01/n_threads;
- if (start < end) {
- ctx->tasks.push_back(std::async(std::launch::async, [=]() {
- for (int64_t i01 = start; i01 < end; i01++) {
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
- }
- }));
- }
- }
- {
- // reuse the current thread for the first task
- const int64_t start = 0;
- const int64_t end = ne01/n_threads;
- for (int64_t i01 = start; i01 < end; i01++) {
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
- }
- }
-#endif
- }
- }
-
-#ifndef GGML_USE_OPENMP
- // wait for all tasks to finish
- for (auto & task : ctx->tasks) {
- task.get();
- }
- ctx->tasks.clear();
-#endif
- }
-
-#if defined(OPENBLAS_VERSION)
- openblas_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_BLIS)
- bli_thread_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_NVPL)
- nvpl_blas_set_num_threads(ctx->n_threads);
-#endif
-
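-    // src0/src1 are broadcast over dims 2 and 3; each 2D plane is computed as
-    // dst = src1 * src0^T with a single sgemm call (K = ne00 == ne10)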
- for (int64_t i13 = 0; i13 < ne13; i13++) {
- for (int64_t i12 = 0; i12 < ne12; i12++) {
- const int64_t i03 = i13/r3;
- const int64_t i02 = i12/r2;
-
- const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
-
- if (type != GGML_TYPE_F32) {
- x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
- }
-
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne1, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
- }
- }
-}
-
-static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(ne0 == ne00);
- GGML_ASSERT(ne1 == ne10);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne3 == ne13);
- GGML_ASSERT(ne03 == ne13);
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == sizeof(float));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- // GGML_ASSERT(nb0 <= nb1);
- // GGML_ASSERT(nb1 <= nb2);
- // GGML_ASSERT(nb2 <= nb3);
-
- // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
- // src0: (k,n)
- // src1: (k,m)
- // dst: (m,n)
- //
- // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
- // Also expressed as (major,minor)
- // a: (m,k): so src1 transposed
- // b: (k,n): so src0
- // c: (m,n)
- //
- // However, if ggml_is_transposed(src1) is true, then
- // src1->data already contains a transposed version, so sgemm mustn't
- // transpose it further.
-
- int n = src0->ne[0];
- int k = src0->ne[1];
- int m = src1->ne[0];
-
- CBLAS_TRANSPOSE transposeA;
- int lda;
-
- if (!ggml_is_transposed(src1)) {
- transposeA = CblasTrans;
- lda = m;
- } else {
- transposeA = CblasNoTrans;
- lda = k;
- }
-
- float * a = (float *) ((char *) src1->data);
- float * b = (float *) ((char *) src0->data);
- float * c = (float *) ((char *) dst->data);
-
- cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
-
- GGML_UNUSED(ctx);
-}
-
-// backend interface
-
-static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
- return "BLAS";
-
- GGML_UNUSED(backend);
-}
-
-static void ggml_backend_blas_free(ggml_backend_t backend) {
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
- delete ctx;
- delete backend;
-}
-
-static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
-
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
-
- switch (node->op) {
- case GGML_OP_MUL_MAT:
- ggml_backend_blas_mul_mat(ctx, node);
- break;
-
- case GGML_OP_OUT_PROD:
- ggml_backend_blas_out_prod(ctx, node);
- break;
-
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- break;
-
- default:
- GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
- }
- }
-
- return GGML_STATUS_SUCCESS;
-
- GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i blas_backend_i = {
- /* .get_name = */ ggml_backend_blas_get_name,
- /* .free = */ ggml_backend_blas_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_blas_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_blas_guid(void) {
- static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
- return &guid;
-}
-
-ggml_backend_t ggml_backend_blas_init(void) {
- ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
-
- ggml_backend_t backend = new ggml_backend {
- /* .guid = */ ggml_backend_blas_guid(),
- /* .interface = */ blas_backend_i,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
- /* .context = */ ctx,
- };
-
-#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
- if (openblas_get_parallel() != OPENBLAS_OPENMP) {
- GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
- }
-#endif
-
-#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
- GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
-#endif
-
- return backend;
-}
-
-bool ggml_backend_is_blas(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
-}
-
-void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
- GGML_ASSERT(ggml_backend_is_blas(backend_blas));
-
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
- ctx->n_threads = n_threads;
-}
-
-// device interface
-
-static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
- return "BLAS";
-
- GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
- #if defined(GGML_USE_ACCELERATE)
- return "Accelerate";
- #elif defined(GGML_BLAS_USE_MKL)
- return "MKL";
- #elif defined(GGML_BLAS_USE_BLIS)
- return "BLIS";
- #elif defined(GGML_BLAS_USE_NVPL)
- return "NVPL";
- #elif defined(OPENBLAS_VERSION)
- return "OpenBLAS";
- #else
- return "BLAS";
- #endif
-
- GGML_UNUSED(dev);
-}
-
-static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- // TODO
- *free = 0;
- *total = 0;
-
- GGML_UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
- return GGML_BACKEND_DEVICE_TYPE_ACCEL;
-
- GGML_UNUSED(dev);
-}
-
-static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_blas_device_get_name(dev);
- props->description = ggml_backend_blas_device_get_description(dev);
- props->type = ggml_backend_blas_device_get_type(dev);
- ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ true,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
- return ggml_backend_blas_init();
-
- GGML_UNUSED(dev);
- GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
- return ggml_backend_cpu_buffer_type();
-
- GGML_UNUSED(dev);
-}
-
-static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
- return ggml_backend_cpu_buffer_from_ptr(ptr, size);
-
- GGML_UNUSED(dev);
- GGML_UNUSED(max_tensor_size);
-}
-
-static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- const struct ggml_tensor * src0 = op->src[0];
- const struct ggml_tensor * src1 = op->src[1];
-
- switch (op->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- return true;
-
- case GGML_OP_MUL_MAT:
- {
-                // BLAS is usually only faster for large matrices
- const struct ggml_tensor * src0 = op->src[0];
- const struct ggml_tensor * src1 = op->src[1];
-
- const int64_t ne10 = src1->ne[0];
-
- const int64_t ne0 = op->ne[0];
- const int64_t ne1 = op->ne[1];
-
- // TODO: find the optimal value
- const int64_t min_batch = 32;
-
- return ggml_is_contiguous(src0) &&
- ggml_is_contiguous(src1) &&
- src1->type == GGML_TYPE_F32 &&
- (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
- (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
- }
-
- case GGML_OP_OUT_PROD:
- return op->src[0]->type == GGML_TYPE_F32 &&
- op->src[1]->type == GGML_TYPE_F32 &&
- ggml_is_matrix(src0) &&
- ggml_is_matrix(src1) &&
- ggml_is_contiguous(src0) &&
- (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
- (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
-
- default:
- return false;
-
- }
-
- GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return ggml_backend_buft_is_host(buft);
-
- GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
- /* .get_name = */ ggml_backend_blas_device_get_name,
- /* .get_description = */ ggml_backend_blas_device_get_description,
- /* .get_memory = */ ggml_backend_blas_device_get_memory,
- /* .get_type = */ ggml_backend_blas_device_get_type,
- /* .get_props = */ ggml_backend_blas_device_get_props,
- /* .init_backend = */ ggml_backend_blas_device_init_backend,
- /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
- /* .supports_op = */ ggml_backend_blas_device_supports_op,
- /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
- /* .offload_op = */ NULL,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
- return "BLAS";
-
- GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
- return 1;
-
- GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- GGML_ASSERT(index == 0);
-
- static ggml_backend_device ggml_backend_blas_device = {
- /* .iface = */ ggml_backend_blas_device_i,
- /* .reg = */ reg,
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_blas_device;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(index);
-}
-
-static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
- return (void *)ggml_backend_blas_set_n_threads;
- }
- return NULL;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(name);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
- /* .get_name = */ ggml_backend_blas_reg_get_name,
- /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
- /* .get_device = */ ggml_backend_blas_reg_get_device,
- /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_blas_reg(void) {
- static struct ggml_backend_reg ggml_backend_blas_reg = {
- /* .iface = */ ggml_backend_blas_reg_i,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_blas_reg;
-}
+++ /dev/null
-/*
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
- * IN THE SOFTWARE.
- */
-
-#include "ggml-cann.h"
-
-#include <acl/acl.h>
-#include <stdarg.h>
-
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <mutex>
-
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-#include "ggml-cann/aclnn_ops.h"
-#include "ggml-cann/common.h"
-
-#define GGML_COMMON_DECL_C
-
-#include "ggml-common.h"
-
-#define GGML_CANN_NAME "CANN"
-
-/**
- * @brief Handles CANN errors by printing an error message and aborting.
- *
- * @param stmt The statement that caused the error.
- * @param func The function in which the error occurred.
- * @param file The file in which the error occurred.
- * @param line The line number where the error occurred.
- * @param msg The error message.
- */
-[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
- const char* file, int line, const char* msg) {
- int32_t id = -1;
- aclrtGetDevice(&id);
-
- GGML_LOG_ERROR("CANN error: %s\n", msg);
- GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
- file, line);
- GGML_LOG_ERROR(" %s\n", stmt);
-    // abort with GGML_ABORT to get a stack trace
- GGML_ABORT("CANN error");
-}
-
-/**
- * @brief Sets the device to be used by CANN.
- *
- * @param device The device ID to set.
- */
-void ggml_cann_set_device(const int32_t device) {
-    // TODO: uncomment these lines after the empty-context issue has been fixed.
- // int current_device;
-    // ACL_CHECK(aclrtGetDevice(&current_device));
-
- // if (device == current_device) {
- // return;
- // }
- ACL_CHECK(aclrtSetDevice(device));
-}
-
-/**
- * @brief Retrieves the current device ID.
- *
- * @return The current device ID.
- */
-int32_t ggml_cann_get_device() {
- int32_t id;
- ACL_CHECK(aclrtGetDevice(&id));
- return id;
-}
-
-/**
- * @brief Initialize the CANN device information.
- *
- * This function initializes the CANN device information by obtaining the
- * device count and querying the memory allocation granularity of each device.
- *
- * @return A structure containing the device information.
- */
-static ggml_cann_device_info ggml_cann_init() {
- ggml_cann_device_info info = {};
-
- aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
-
- if (err != ACL_SUCCESS) {
- GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
- __func__, aclGetRecentErrMsg());
- return info;
- }
-
- GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
-
- for (int id = 0; id < info.device_count; ++id) {
- aclrtPhysicalMemProp prop = {};
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
- prop.memAttr = ACL_HBM_MEM_HUGE;
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
- prop.location.id = id;
- prop.reserve = 0;
- ACL_CHECK(aclrtMemGetAllocationGranularity(
- &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
- &info.devices[id].vmm_granularity));
- }
-
- // TODO: add more device info later.
- return info;
-}
-
-/**
- * @brief Retrieve the CANN device information.
- *
- * This function returns a reference to a structure containing the CANN device
- * information. The device information is initialized once and reused on
- * subsequent calls.
- *
- * @return A reference to the structure containing the device information.
- */
-const ggml_cann_device_info& ggml_cann_info() {
- static ggml_cann_device_info info = ggml_cann_init();
- return info;
-}
-
-//#define DEBUG_CANN_MALLOC
-/**
- * @brief A pool of CANN buffers (legacy).
- *
- * This class manages a pool of CANN buffers for a specific device.
- */
-struct ggml_cann_pool_leg : public ggml_cann_pool {
- /**
- * @brief The maximum number of buffers in the pool.
- */
- static const int MAX_BUFFERS = 256;
-
- /**
- * @brief The device ID associated with this buffer pool.
- */
- int device;
-
- /**
- * @brief Structure representing a CANN buffer.
- */
- struct ggml_cann_buffer {
- void* ptr = nullptr; ///< Pointer to the buffer memory.
- size_t size = 0; ///< Size of the buffer.
- };
-
- /**
- * @brief Array of CANN buffers in the pool.
- */
- ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
-
- /**
- * @brief Total size of all buffers in the pool.
- */
- size_t pool_size = 0;
-
- /**
- * @brief Constructor to initialize the buffer pool for a specific device.
- *
- * @param device The device ID to associate with this buffer pool.
- */
- explicit ggml_cann_pool_leg(int device) : device(device) {}
-
- /**
- * @brief Destructor to free all buffers in the pool.
- */
- ~ggml_cann_pool_leg() {
- ggml_cann_set_device(device);
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cann_buffer& b = buffer_pool[i];
- if (b.ptr != nullptr) {
- ACL_CHECK(aclrtFree(b.ptr));
- pool_size -= b.size;
- }
- }
- GGML_ASSERT(pool_size == 0);
- }
-
- /**
- * @brief Allocate a buffer of the given size.
- *
- * @param size The size of the buffer to allocate.
- * @param actual_size A pointer to a variable to receive the actual size of
- * the allocated buffer.
- * @return A pointer to the allocated buffer.
- */
- void* alloc(size_t size, size_t* actual_size) override {
-#ifdef DEBUG_CANN_MALLOC
- int nnz = 0;
- size_t max_size = 0;
-#endif
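-        // best-fit search: reuse the smallest pooled buffer that is large enough;
-        // an exact size match is returned immediately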
- size_t best_diff = 1ull << 36;
- int ibest = -1;
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cann_buffer& b = buffer_pool[i];
- if (b.ptr != nullptr) {
-#ifdef DEBUG_CANN_MALLOC
- ++nnz;
- if (b.size > max_size) max_size = b.size;
-#endif
- if (b.size >= size) {
- size_t diff = b.size - size;
- if (diff < best_diff) {
- best_diff = diff;
- ibest = i;
- if (!best_diff) {
- void* ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- }
- }
- }
- }
- if (ibest >= 0) {
- ggml_cann_buffer& b = buffer_pool[ibest];
- void* ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
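-        // no suitable pooled buffer: allocate a new one, over-allocating by 5%
-        // (rounded up to a multiple of 256 bytes) to improve future reuse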
- void* ptr;
- size_t look_ahead_size = (size_t)(1.05 * size);
- look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
- ggml_cann_set_device(device);
- ACL_CHECK(
- aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
- *actual_size = look_ahead_size;
- pool_size += look_ahead_size;
-#ifdef DEBUG_CANN_MALLOC
- GGML_LOG_INFO(
- "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
- "requested %u MB\n",
- __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
- (uint32_t)(pool_size / 1024 / 1024),
- (uint32_t)(size / 1024 / 1024));
-#endif
- return ptr;
- }
-
- /**
- * @brief Free a buffer and return it to the pool.
- *
- * @param ptr Pointer to the buffer to free.
- * @param size Size of the buffer to free.
- */
- void free(void* ptr, size_t size) override {
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cann_buffer& b = buffer_pool[i];
- if (b.ptr == nullptr) {
- b.ptr = ptr;
- b.size = size;
- return;
- }
- }
-        // memory should always be buffered: it may still be needed by
-        // tasks in the stream.
-        // TODO: fix me.
- GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
- }
-};
-
-/**
- * @brief A pool of CANN buffers with virtual memory.
- *
- * This class manages a pool of CANN buffers with virtual memory for a specific
- * device.
- */
-struct ggml_cann_pool_vmm : public ggml_cann_pool {
- /**
- * @brief The maximum size of the virtual memory pool (32 GB).
- */
- static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
-
- /**
- * @brief The device ID associated with this buffer pool.
- */
- int device;
-
- /**
- * @brief Pointer to the start of the virtual memory pool.
- */
- void* pool_addr = 0;
-
- /**
- * @brief Amount of virtual memory used in the pool.
- */
- size_t pool_used = 0;
-
- /**
- * @brief Total size of the virtual memory pool.
- */
- size_t pool_size = 0;
-
- /**
- * @brief Allocation granularity for the virtual memory pool.
- */
- size_t granularity;
-
- /**
- * @brief Handles for the physical memory allocated.
- */
- std::vector<aclrtDrvMemHandle> handles;
-
- /**
- * @brief Offsets for the mapped memory regions.
- */
- std::vector<void*> map_offsets;
-
- /**
- * @brief Constructor to initialize the buffer pool with virtual memory for
- * a specific device.
- *
- * @param device The device ID to associate with this buffer pool.
- */
- explicit ggml_cann_pool_vmm(int device)
- : device(device),
- granularity(ggml_cann_info().devices[device].vmm_granularity) {}
-
- /**
- * @brief Destructor to free all buffers in the virtual memory pool.
- */
- ~ggml_cann_pool_vmm() {
- if (pool_addr != 0) {
- for (auto& offset : map_offsets) {
- ACL_CHECK(aclrtUnmapMem(offset));
- }
- for (auto& handle : handles) {
- ACL_CHECK(aclrtFreePhysical(handle));
- }
- ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
- }
- }
-
- /**
- * @brief Allocate a buffer of the given size in the virtual memory pool.
- *
- * @param size The size of the buffer to allocate.
- * @param actual_size A pointer to a variable to receive the actual size of
- * the allocated buffer.
- * @return A pointer to the allocated buffer.
- */
- void* alloc(size_t size, size_t* actual_size) override {
- // round up the allocation size to the alignment to ensure that all
- // allocations are aligned for all data types
- const size_t alignment = 128;
- size = alignment * ((size + alignment - 1) / alignment);
-
- size_t avail = pool_size - pool_used;
-
- if (size > avail) {
- // round up to the next multiple of the granularity
- size_t reserve_size = size - avail;
- reserve_size =
- granularity * ((reserve_size + granularity - 1) / granularity);
-
- GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
-
- // allocate more physical memory
- aclrtPhysicalMemProp prop = {};
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
- prop.memAttr = ACL_HBM_MEM_HUGE;
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
- prop.location.id = device;
- prop.reserve = 0;
- aclrtDrvMemHandle handle;
- ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
-
- // reserve virtual address space (if not already reserved)
- if (pool_addr == 0) {
- ACL_CHECK(aclrtReserveMemAddress(
- &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
- }
-
- // map at the end of the pool
- ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
- handle, 0));
-
- handles.push_back(handle);
- map_offsets.push_back((char*)pool_addr + pool_size);
-
- // add to the pool
- pool_size += reserve_size;
-
- // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
- // reserved %llu MB)\n",
- // device, (unsigned long long) (pool_size/1024/1024),
- // (unsigned long long) (reserve_size/1024/1024));
- }
-
- GGML_ASSERT(pool_addr != 0);
-
- void* ptr = (void*)((char*)pool_addr + pool_used);
- *actual_size = size;
- pool_used += size;
-
-#ifdef DEBUG_CANN_MALLOC
- GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
- (unsigned long long)size, (unsigned long long)ptr);
-#endif
- return ptr;
- }
-
- /**
- * @brief Free a buffer and return it to the virtual memory pool.
- *
- * @param ptr Pointer to the buffer to free.
- * @param size Size of the buffer to free.
- */
- void free(void* ptr, size_t size) override {
-#ifdef DEBUG_CANN_MALLOC
- GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
- (unsigned long long)size, (unsigned long long)ptr);
-#endif
-
- pool_used -= size;
-
- // all deallocations must be in reverse order of the allocations
- GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
- }
-};
-
-/**
- * @brief Create a new CANN pool for a specific device.
- *
- * Factory method to create a new CANN pool object based on the device type.
- *
- * @param device The device ID for which to create the pool.
- * @return A unique pointer to the created CANN pool.
- */
-std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
- int device) {
- // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
-}
-
-// cann buffer
-/**
- * @brief Context for managing a CANN buffer associated with a specific device.
- *
- * This structure holds information about a CANN buffer, including the device
- * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
- */
-struct ggml_backend_cann_buffer_context {
- int32_t device; ///< The device ID associated with this buffer context.
- void* dev_ptr =
- nullptr; ///< Pointer to the device memory allocated for the buffer.
-
- /**
- * @brief Constructor to initialize the CANN buffer context.
- *
- * @param device The device ID associated with this buffer context.
- * @param dev_ptr Pointer to the device memory allocated for the buffer.
- */
- ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
- : device(device),
- dev_ptr(dev_ptr) {}
-
- /**
- * @brief Destructor to free the device memory allocated for the buffer.
- */
- ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
-};
-
-/**
- * @brief Check if a buffer is a CANN buffer.
- *
- * This function checks if a given buffer is a CANN buffer by checking
- * whether its buffer type is a CANN buffer type.
- *
- * @param buffer The buffer to check.
- * @return true if the buffer is a CANN buffer, false otherwise.
- */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
-static bool ggml_backend_buffer_is_cann(
- ggml_backend_buffer_t buffer) {
- return ggml_backend_buft_is_cann(buffer->buft);
-}
-
-/**
- * @brief Free resources associated with a CANN buffer.
- *
- * This function frees the resources associated with a CANN buffer, including
- * its context.
- *
- * @param buffer The CANN buffer to free.
- */
-static void ggml_backend_cann_buffer_free_buffer(
- ggml_backend_buffer_t buffer) {
- ggml_backend_cann_buffer_context* ctx =
- (ggml_backend_cann_buffer_context*)buffer->context;
- delete ctx;
-}
-
-/**
- * @brief Retrieve the base pointer of a CANN buffer.
- *
- * This function returns the base pointer of a CANN buffer, which points to the
- * device memory allocated for the buffer.
- *
- * @param buffer The CANN buffer whose base pointer is to be retrieved.
- * @return A pointer to the base of the device memory allocated for the buffer.
- */
-static void* ggml_backend_cann_buffer_get_base(
- ggml_backend_buffer_t buffer) {
- ggml_backend_cann_buffer_context* ctx =
- (ggml_backend_cann_buffer_context*)buffer->context;
- return ctx->dev_ptr;
-}
-
-/**
- * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
- * processing.
- *
- * This function transforms quantized Q4.0 tensor data into a format suitable
- * for CANN processing. It extracts quantization values and scales from the
- * source data and prepares them in a format expected by CANN operations.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data in Q4.0 format.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
- const void* src,
- void* dst) {
-
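-    // destination layout: all 4-bit quantized values packed first (quant_bytes bytes),
-    // followed by the fp16 scale of each block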
- int64_t n_elems = ggml_nelements(tensor);
- int64_t groups = n_elems / QK4_0;
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
-
- uint8_t* quant_offset = (uint8_t*)dst;
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
-
- for (int i = 0; i < groups; i++) {
- const block_q4_0* group =
- (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
- *scale_offset = group->d;
- scale_offset++;
-
- // 0-15
- for (int j = 0; j < QK4_0 / 2; j += 2) {
- (*quant_offset) = (group->qs[j] & 0x0F);
- (*quant_offset) |= ((group->qs[j + 1] << 4));
- quant_offset++;
- }
-
- // 16-31
- for (int j = 0; j < QK4_0 / 2; j += 2) {
- (*quant_offset) = (group->qs[j] >> 4);
- (*quant_offset) |= (group->qs[j + 1] & 0xF0);
- quant_offset++;
- }
- }
-
- // put (uint4b_t -8) into int4b_t
- for (quant_offset = (uint8_t*)dst;
- quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
- (*quant_offset) ^= 0x88;
- }
-}
-
-/**
- * @brief Transform CANN processed data back into quantized Q4.0 format.
- *
- * This function transforms CANN processed data back into quantized Q4.0 format.
- * It reverses the transformation performed by
- * ggml_backend_cann_transform_q4_0(), converting the data back into its
- * original quantized form.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source buffer containing transformed data.
- * @param dst Pointer to the destination buffer where the Q4.0 formatted data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back_q4_0(
- const ggml_tensor* tensor, void* src, void* dst) {
-
- int64_t n_elems = ggml_nelements(tensor);
- int64_t groups = n_elems / QK4_0;
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
-
- uint8_t* quant_offset = (uint8_t*)src;
- uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
-
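-    // undo the 0x88 XOR applied by ggml_backend_cann_transform_q4_0
-    // (convert int4b_t values back to uint4b_t)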
- for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
- (*quant_offset) ^= 0x88;
- }
- quant_offset = (uint8_t*)src;
-
- for (int i = 0; i < groups; i++) {
- block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
- group->d = *scale_offset;
- scale_offset++;
-
- // 0-15
- for (int j = 0; j < QK4_0 / 2; j += 2) {
- group->qs[j] = ((*quant_offset) & 0x0F);
- group->qs[j + 1] = ((*quant_offset) >> 4);
- quant_offset++;
- }
-
- // 16-31
- for (int j = 0; j < QK4_0 / 2; j += 2) {
- group->qs[j] |= ((*quant_offset) << 4);
- group->qs[j + 1] |= ((*quant_offset) & 0xF0);
- quant_offset++;
- }
- }
-}
-
-/**
- * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
- * processing.
- *
- * This function transforms quantized Q8.0 tensor data into a format suitable
- * for CANN processing. It extracts quantization values and scales from the
- * source data and prepares them in a format expected by CANN operations.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data in Q8.0 format.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
- const void* src,
- void* dst) {
- int64_t n_elems = ggml_nelements(tensor);
- int64_t groups = n_elems / QK8_0;
- size_t quant_bytes = n_elems * sizeof(uint8_t);
-
- uint8_t* quant_offset = (uint8_t*)dst;
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
-
- for (int i = 0; i < groups; i++) {
- const block_q8_0* group =
- (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
- *scale_offset = group->d;
- scale_offset++;
- size_t group_quant_size = QK8_0 * sizeof(uint8_t);
- memcpy(quant_offset, group->qs, group_quant_size);
- quant_offset += group_quant_size;
- }
-}
-
-/**
- * @brief Transform CANN processed data back into quantized Q8.0 format.
- *
- * This function transforms CANN processed data back into quantized Q8.0 format.
- * It reverses the transformation performed by
- * ggml_backend_cann_transform_q8_0(), converting the data back into its
- * original quantized form.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source buffer containing transformed data.
- * @param dst Pointer to the destination buffer where the Q8.0 formatted data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back_q8_0(
- const ggml_tensor* tensor, const void* src, void* dst) {
- int64_t n_elems = ggml_nelements(tensor);
- int64_t groups = n_elems / QK8_0;
- size_t quant_bytes = n_elems * sizeof(uint8_t);
-
- const uint8_t* quant_offset = (const uint8_t*)src;
- const uint16_t* scale_offset =
- (const uint16_t*)((const char*)src + quant_bytes);
-
- for (int i = 0; i < groups; i++) {
- block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
- group->d = *scale_offset;
- scale_offset++;
- size_t group_quant_size = QK8_0 * sizeof(uint8_t);
- memcpy(group->qs, quant_offset, group_quant_size);
- quant_offset += group_quant_size;
- }
-}
-
-/**
- * @brief Transform tensor data based on its type for CANN processing.
- *
- * This function transforms tensor data based on its quantization type for CANN
- * processing. It dispatches the transformation based on the tensor's type to
- * specialized functions handling Q4.0 and Q8.0 formats.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data to be transformed.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform(ggml_tensor* tensor,
- const void* src, void* dst) {
- switch (tensor->type) {
- case GGML_TYPE_Q4_0:
- ggml_backend_cann_transform_q4_0(tensor, src, dst);
- break;
- case GGML_TYPE_Q8_0:
- ggml_backend_cann_transform_q8_0(tensor, src, dst);
- break;
- default:
- break;
- }
-}
-
-/**
- * @brief Transform CANN processed data back into tensor data based on its type.
- *
- * This function transforms CANN processed data back into tensor data based on
- * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
- * transformation based on the tensor's type to specialized functions.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data containing CANN processed data.
- * @param dst Pointer to the destination buffer where transformed tensor data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back(
- const ggml_tensor* tensor, void* src, void* dst) {
- switch (tensor->type) {
- case GGML_TYPE_Q4_0:
- ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
- break;
- case GGML_TYPE_Q8_0:
- ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
- break;
- default:
- break;
- }
-}
-
-/**
- * @brief Check if transformation is needed for a given tensor type.
- *
- * This function checks if transformation is needed for a given tensor type
- * to prepare data for CANN processing.
- *
- * @param type The tensor type to check.
- * @return true if transformation is needed, false otherwise.
- */
-static bool need_transform(ggml_type type) {
- switch (type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q8_0:
- return true;
- default:
- return false;
- }
-}
-
-/**
- * @brief Initialize a tensor using data from a CANN buffer.
- *
- * This function initializes a tensor using data from a CANN buffer.
- * It handles special cases such as views and quantization.
- *
- * @param buffer The CANN buffer from which to initialize the tensor.
- * @param tensor Pointer to the tensor to be initialized.
- */
-static void ggml_backend_cann_buffer_init_tensor(
- ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
- if (tensor->view_src != NULL && tensor->view_offs == 0) {
- GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
- return;
- }
-
-    // TODO: the cann backend doesn't support quantized tensors yet. Just leave
-    // the code here.
- if (ggml_is_quantized(tensor->type)) {
- // Initialize padding to 0 to avoid possible NaN values
- size_t original_size = ggml_nbytes(tensor);
- size_t padded_size =
- ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
- if (padded_size > original_size && tensor->view_src == nullptr) {
- size_t memset_size = padded_size - original_size;
- ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
- memset_size, 0, memset_size));
- }
- }
-}
-
-// TODO: need to handle tensors which have padding.
-/**
- * @brief Set tensor data in a CANN buffer.
- *
- * This function sets tensor data in a CANN buffer, handling transformations
- * if needed based on the tensor's type.
- *
- * @param buffer The CANN buffer where the tensor data will be set.
- * @param tensor Pointer to the tensor whose data will be set.
- * @param data Pointer to the source data to be copied into the tensor.
- * @param offset Offset in the source data from where to start copying.
- * @param size Size of the data to be copied, in bytes.
- */
-static void ggml_backend_cann_buffer_set_tensor(
- ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
- size_t offset, size_t size) {
- ggml_backend_cann_buffer_context *ctx =
- (ggml_backend_cann_buffer_context *)buffer->context;
-
- ggml_cann_set_device(ctx->device);
-    // TODO: refer to cann (#6017), it uses the thread's default stream.
- // For acl, synchronous functions use this default stream.
- // Why aclrtSynchronizeDevice?
-
- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
- ACL_MEMCPY_HOST_TO_DEVICE));
- } else {
- void *transform_buffer = malloc(size);
- ggml_backend_cann_transform(tensor, data, transform_buffer);
-
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
- transform_buffer, size,
- ACL_MEMCPY_HOST_TO_DEVICE));
- free(transform_buffer);
- }
-}
-
-/**
- * @brief Get tensor data from a CANN buffer.
- *
- * This function retrieves tensor data from a CANN buffer, handling
- * transformations if needed based on the tensor's type.
- *
- * @param buffer The CANN buffer from which to retrieve tensor data.
- * @param tensor Pointer to the tensor whose data will be retrieved.
- * @param data Pointer to the destination buffer where the tensor data will be
- * copied.
- * @param offset Offset in the destination buffer where to start copying.
- * @param size Size of the data to be copied, in bytes.
- */
-static void ggml_backend_cann_buffer_get_tensor(
- ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
- size_t offset, size_t size) {
- ggml_backend_cann_buffer_context* ctx =
- (ggml_backend_cann_buffer_context*)buffer->context;
-
- ggml_cann_set_device(ctx->device);
-
- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
- ACL_MEMCPY_DEVICE_TO_HOST));
- } else {
- void* transform_buffer = malloc(size);
- ACL_CHECK(aclrtMemcpy(transform_buffer, size,
- (char*)tensor->data + offset, size,
- ACL_MEMCPY_DEVICE_TO_HOST));
- ggml_backend_cann_transform_back(tensor, transform_buffer, data);
- free(transform_buffer);
- }
-}
-
-/**
- * @brief Copy tensor data between CANN buffers if possible.
- *
- * This function copies tensor data between CANN buffers if the source and
- * destination buffers are CANN buffers and they meet the necessary conditions
- * (same device or devices can access each other).
- *
- * @param buffer The destination CANN buffer where the tensor data will be
- * copied.
- * @param src Pointer to the source tensor whose data will be copied.
- * @param dst Pointer to the destination tensor where the data will be copied.
- * @return true if the copy operation succeeded, false otherwise.
- */
-static bool ggml_backend_cann_buffer_cpy_tensor(
- ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
- if (ggml_backend_buffer_is_cann(src->buffer)) {
- ggml_backend_cann_buffer_context* src_ctx =
- (ggml_backend_cann_buffer_context*)src->buffer->context;
- ggml_backend_cann_buffer_context* dst_ctx =
- (ggml_backend_cann_buffer_context*)buffer->context;
-
- size_t memcpy_size = ggml_nbytes(src);
- // Same device.
- if (src_ctx->device == dst_ctx->device) {
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
- (const char*)src->data, memcpy_size,
- ACL_MEMCPY_DEVICE_TO_DEVICE));
- return true;
- } else {
- // Different device but can access by peer.
- int32_t canAccessPeer = 0;
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
- dst_ctx->device));
- if (canAccessPeer) {
- ggml_cann_set_device(src_ctx->device);
- ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
- (const char*)src->data, memcpy_size,
- ACL_MEMCPY_DEVICE_TO_DEVICE));
- return true;
- }
- }
- }
- return false;
-}
-
-/**
- * @brief Clear a CANN buffer by setting all its memory to a specified value.
- *
- * This function clears a CANN buffer by setting all its memory to a specified
- * value.
- *
- * @param buffer The CANN buffer to be cleared.
- * @param value The value to which each byte in the buffer will be set.
- */
-static void ggml_backend_cann_buffer_clear(
- ggml_backend_buffer_t buffer, uint8_t value) {
- ggml_backend_cann_buffer_context* ctx =
- (ggml_backend_cann_buffer_context*)buffer->context;
-
- ggml_cann_set_device(ctx->device);
- ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
-}
-
-/**
- * @brief Interface for a CANN buffer in the backend.
- *
- * This structure defines function pointers to operations that can be performed
- * on a CANN buffer within the backend.
- */
-static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
- /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
- /* .get_base = */ ggml_backend_cann_buffer_get_base,
- /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_cann_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// cann buffer type
-/**
- * @brief Structure representing context information for a specific backend
- * buffer type.
- */
-struct ggml_backend_cann_buffer_type_context {
-    int32_t device;   /**< Device identifier associated with the buffer context. */
- std::string name; /**< Name associated with the buffer context. */
-};
-
-/**
- * @brief Retrieves the name associated with a CANN buffer type.
- *
- * This function returns the descriptive name associated with the specified
- * CANN buffer type context.
- *
- * @param buft Pointer to the buffer type context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char* ggml_backend_cann_buffer_type_name(
- ggml_backend_buffer_type_t buft) {
- ggml_backend_cann_buffer_type_context* buft_ctx =
- (ggml_backend_cann_buffer_type_context*)buft->context;
-
- return buft_ctx->name.c_str();
-}
-
-/**
- * @brief Allocates a new CANN buffer of the specified type and size.
- *
- * This function allocates a new CANN buffer on the specified device with the
- * given size.
- *
- * @param buft Pointer to the buffer type context.
- * @param size Size in bytes of the buffer to allocate.
- * @return Pointer to the allocated buffer, or nullptr if allocation fails.
- */
-static ggml_backend_buffer_t
-ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
- size_t size) {
- ggml_backend_cann_buffer_type_context* buft_ctx =
- (ggml_backend_cann_buffer_type_context*)buft->context;
-
- ggml_cann_set_device(buft_ctx->device);
-
- size = std::max(size, (size_t)1);
-
- void* dev_ptr;
- aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
- if (err != ACL_SUCCESS) {
- GGML_LOG_ERROR(
- "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
- __func__, size / 1024.0 / 1024.0, buft_ctx->device,
- aclGetRecentErrMsg());
- return nullptr;
- }
-
- ggml_backend_cann_buffer_context* ctx =
- new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
-
- return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
- ctx, size);
-}
-
-/**
- * @brief Retrieves the memory alignment requirement for CANN buffers of this
- * type.
- *
- * This function returns the alignment requirement in bytes for memory allocated
- * by the CANN buffer type.
- *
- * @param buft Pointer to the buffer type context (unused in this
- * implementation).
- * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
- * buffers).
- */
-static size_t ggml_backend_cann_buffer_type_get_alignment(
- ggml_backend_buffer_type_t buft) {
- return 128;
-
- GGML_UNUSED(buft);
-}
-
-/**
- * @brief Calculates the allocation size required for a tensor in a CANN buffer.
- *
- * Computes the total allocation size needed for storing the tensor's data in a
- * CANN buffer, considering any necessary padding or adjustments for quantized
- * types.
- *
- * @param buft Pointer to the buffer type context (unused in this
- * implementation).
- * @param tensor Pointer to the tensor for which the allocation size is
- * calculated.
- * @return The total allocation size in bytes required for the tensor in the
- * CANN buffer.
- */
-static size_t ggml_backend_cann_buffer_type_get_alloc_size(
- ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
- size_t size = ggml_nbytes(tensor);
- int64_t ne0 = tensor->ne[0];
-
-    // the last row must be at least 32 bytes, because every single op deals
-    // with at least 32 bytes.
- // TODO: quantized type?
- // int64_t line_size = ne0 * ggml_element_size(tensor);
- // int64_t line_size_align_32 = (line_size + 31) & ~31;
- // size += (line_size_align_32 - line_size);
-
-    // TODO: quantized types are not supported yet.
-    // TODO: consider non-contiguous tensors.
- if (ggml_is_quantized(tensor->type)) {
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(
- tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
- }
-
- return size;
-
- GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return false;
-
- GGML_UNUSED(buft);
-}
-
-/**
- * @brief Interface for managing CANN buffer types in the GGML backend.
- *
- * Provides function pointers for allocating, querying properties, and managing
- * memory for CANN buffer types in the GGML backend.
- */
-static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
- /* .get_name = */ ggml_backend_cann_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
- /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
-};
-
-/**
- * @brief Retrieves the CANN buffer type for a specified device.
- *
- * This function initializes and returns the buffer type interface associated
- * with the given device. It ensures thread-safe access using a mutex.
- *
- * @param device The device index for which to retrieve the buffer type.
- * @return A pointer to the buffer type interface for the specified device, or
- * nullptr if the device index is out of range.
- */
-ggml_backend_buffer_type_t
-ggml_backend_cann_buffer_type(int32_t device) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- if (device >= ggml_backend_cann_get_device_count()) {
- return nullptr;
- }
-
- static ggml_backend_buffer_type
- ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
-
- static bool ggml_backend_cann_buffer_type_initialized = false;
-
- if (!ggml_backend_cann_buffer_type_initialized) {
-        // initialize one buffer type per available device and bind each entry
-        // to its own registry device
-        for (int32_t i = 0; i < ggml_backend_cann_get_device_count(); i++) {
-            ggml_backend_cann_buffer_types[i] = {
-                /* .iface = */ ggml_backend_cann_buffer_type_interface,
-                /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
- /* .context = */
- new ggml_backend_cann_buffer_type_context{
- i, "CANN" + std::to_string(i)},
- };
- }
- ggml_backend_cann_buffer_type_initialized = true;
- }
-
- return &ggml_backend_cann_buffer_types[device];
-}
-
-/**
- * @brief Retrieves the name associated with a CANN host buffer type.
- *
- * This function returns the descriptive name associated with the specified
- * CANN host buffer type context.
- *
- * @param buft Pointer to the host buffer type context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
- return "CANN_Host";
-
- GGML_UNUSED(buft);
-}
-
-/**
- * @brief Retrieves the name associated with a CANN host buffer.
- *
- * This function returns the descriptive name associated with the specified
- * CANN host buffer context.
- *
- * @param buffer Pointer to the host buffer.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
- return "CANN_Host";
-
- GGML_UNUSED(buffer);
-}
-
-/**
- * @brief Free resources associated with a CANN host buffer.
- *
- * This function frees the resources associated with a CANN host buffer, including
- * its context.
- *
- * @param buffer The CANN host buffer to free.
- */
-static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
- ACL_CHECK(aclrtFreeHost(buffer->context));
-}
-
-/**
- * @brief Allocates pinned host memory of the specified size.
 *
- * This function allocates pinned (page-locked) host memory with the given size.
- * @param size Size in bytes of the host memory to allocate.
- * @return Pointer to the allocated host memory, or nullptr if allocation fails.
- */
-static void * ggml_cann_host_malloc(size_t size) {
- if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
- return nullptr;
- }
-
- void * hostPtr = nullptr;
- aclError err = aclrtMallocHost((void **) &hostPtr, size);
- if (err != ACL_SUCCESS) {
-
- GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
- size / 1024.0 / 1024.0, aclGetRecentErrMsg());
- return nullptr;
- }
- return hostPtr;
-}
-
-/**
- * @brief Allocates a new CANN host buffer of the specified type and size.
- *
- * @param buft Pointer to the host buffer type context.
- * @param size Size in bytes of the host buffer to allocate.
- * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
- */
-static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- void * hostPtr = ggml_cann_host_malloc(size);
-
- if (hostPtr == nullptr) {
- // fallback to cpu buffer
- return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
- }
-
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
- buffer->buft = buft;
- buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
-
- return buffer;
-}
-
-/**
- * @brief Interface for managing CANN host buffer types in the GGML backend.
- *
- * Provides function pointers for allocating, querying properties, and managing
- * memory for CANN buffer types in the GGML backend.
- */
-ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
- static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_cann_buffer_type_host;
-}
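-
-// Usage sketch (assuming the standard ggml-backend buffer API): pinned host
-// memory obtained through this buffer type speeds up host<->device transfers,
-// and it transparently degrades to a plain CPU buffer when pinning is disabled
-// or fails.
-//
-//     ggml_backend_buffer_t staging =
-//         ggml_backend_buft_alloc_buffer(ggml_backend_cann_host_buffer_type(), 16u * 1024 * 1024);
-//     // ... stage tensor data here, then copy it into device tensors ...
-//     ggml_backend_buffer_free(staging);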
-
-/**
- * @brief Computes the forward operation for a given tensor using CANN
- * operations.
- *
- * This function selects the appropriate CANN operation based on the type of
- * operation specified in the tensor and performs the computation.
- *
- * @param ctx The CANN context containing necessary resources and
- * configurations.
- * @param dst The destination tensor where the result of the computation will be
- * stored.
- * @return true if the computation was successful; false otherwise.
- */
-static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
- struct ggml_tensor* dst) {
- switch (dst->op) {
- case GGML_OP_REPEAT:
- ggml_cann_repeat(ctx, dst);
- break;
- case GGML_OP_GET_ROWS:
- ggml_cann_get_rows(ctx, dst);
- break;
- case GGML_OP_DUP:
- ggml_cann_dup(ctx, dst);
- break;
- case GGML_OP_ADD:
- ggml_cann_add(ctx, dst);
- break;
- case GGML_OP_ACC:
- ggml_cann_acc(ctx, dst);
- break;
- case GGML_OP_MUL:
- ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
- break;
- case GGML_OP_DIV:
- ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
- break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(dst)) {
- case GGML_UNARY_OP_GELU:
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
- ctx, dst);
- break;
- case GGML_UNARY_OP_SILU:
- ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
- ctx, dst);
- break;
- // TODO: Use faster gelu??
- case GGML_UNARY_OP_GELU_QUICK:
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
- ctx, dst);
- break;
- case GGML_UNARY_OP_TANH:
- ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
- ctx, dst);
- break;
- case GGML_UNARY_OP_RELU:
- ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
- ctx, dst);
- break;
- case GGML_UNARY_OP_HARDSIGMOID:
- ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
- aclnnHardsigmoid>(ctx, dst);
- break;
- case GGML_UNARY_OP_HARDSWISH:
- ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
- aclnnHardswish>(ctx, dst);
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_NORM:
- ggml_cann_norm(ctx, dst);
- break;
- case GGML_OP_GROUP_NORM:
- ggml_cann_group_norm(ctx, dst);
- break;
- case GGML_OP_CONCAT:
- ggml_cann_concat(ctx, dst);
- break;
- case GGML_OP_UPSCALE:
- ggml_cann_upsample_nearest2d(ctx, dst);
- break;
- case GGML_OP_PAD:
- ggml_cann_pad(ctx, dst);
- break;
- case GGML_OP_ARANGE:
- ggml_cann_arange(ctx, dst);
- break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- ggml_cann_timestep_embedding(ctx, dst);
- break;
- case GGML_OP_LEAKY_RELU:
- ggml_cann_leaky_relu(ctx, dst);
- break;
- case GGML_OP_RMS_NORM:
- ggml_cann_rms_norm(ctx, dst);
- break;
- case GGML_OP_MUL_MAT:
- ggml_cann_mul_mat(ctx, dst);
- break;
- case GGML_OP_MUL_MAT_ID:
- return false;
- case GGML_OP_SCALE:
- ggml_cann_scale(ctx, dst);
- break;
- case GGML_OP_SQR:
- ggml_cann_sqr(ctx, dst);
- break;
- case GGML_OP_CLAMP:
- ggml_cann_clamp(ctx, dst);
- break;
- case GGML_OP_CPY:
- ggml_cann_cpy(ctx, dst);
- break;
- case GGML_OP_CONT:
- ggml_cann_dup(ctx, dst);
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- break;
- case GGML_OP_DIAG_MASK_INF:
- ggml_cann_diag_mask(ctx, dst, -INFINITY);
- break;
- case GGML_OP_SOFT_MAX:
- ggml_cann_softmax(ctx, dst);
- break;
- case GGML_OP_ROPE:
- ggml_cann_rope(ctx, dst);
- break;
- case GGML_OP_IM2COL:
- ggml_cann_im2col(ctx, dst);
- break;
- case GGML_OP_POOL_2D:
- ggml_cann_pool2d(ctx, dst);
- break;
- case GGML_OP_SUM_ROWS:
- ggml_cann_sum_rows(ctx, dst);
- break;
- case GGML_OP_ARGSORT:
- ggml_cann_argsort(ctx, dst);
- break;
- default:
- return false;
- }
-
- return true;
-}
-
-// backend
-/**
- * @brief Retrieves the name associated with the CANN backend.
- *
- * This function returns the name assigned to the CANN backend, which is stored
- * in the context of the provided backend structure.
- *
- * @param backend Pointer to the CANN backend structure.
- * @return A pointer to a constant string representing the backend name.
- */
-static const char* ggml_backend_cann_name(ggml_backend_t backend) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
-
- return cann_ctx->name.c_str();
-}
-
-/**
- * @brief Frees resources associated with the CANN backend.
- *
- * This function releases resources associated with the CANN backend context
- * and resets the device associated with the backend to its initial state.
- *
- * @param backend Pointer to the CANN backend structure to be freed.
- */
-static void ggml_backend_cann_free(ggml_backend_t backend) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
- ACL_CHECK(aclrtSynchronizeDevice());
- ACL_CHECK(aclrtResetDevice(cann_ctx->device));
-
-    // finalize when the last backend is freed.
- if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
- ACL_CHECK(aclFinalize());
- }
-
- delete cann_ctx;
- delete backend;
-}
-
-/**
- * @brief Sets tensor data asynchronously in the CANN backend.
- *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
- *
- * @param backend Pointer to the CANN backend structure.
- * @param tensor Pointer to the tensor structure to set data for.
- * @param data Pointer to the host data to copy to the tensor.
- * @param offset Offset in bytes within the tensor's data.
- * @param size Size of the data to copy in bytes.
- */
-static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
- ggml_tensor *tensor,
- const void *data,
- size_t offset,
- size_t size) {
- ggml_backend_cann_context *cann_ctx =
- (ggml_backend_cann_context *)backend->context;
-
- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
- size, ACL_MEMCPY_HOST_TO_DEVICE,
- cann_ctx->stream()));
- } else {
- void *transform_buffer = malloc(size);
- ggml_backend_cann_transform(tensor, data, transform_buffer);
-
- ACL_CHECK(aclrtMemcpyAsync(
- (char *)tensor->data + offset, size, transform_buffer, size,
- ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
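-        // the copy above is asynchronous, so wait for the stream before freeing the staging buffer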
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
- free(transform_buffer);
- }
-}
-
-static void ggml_backend_cann_get_tensor_async(
- ggml_backend_t backend, const ggml_tensor *tensor, void *data,
- size_t offset, size_t size) {
- ggml_backend_cann_context *cann_ctx =
- (ggml_backend_cann_context *)backend->context;
- ggml_backend_buffer_t buf =
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
- "unsupported buffer type");
-
- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
- size, ACL_MEMCPY_DEVICE_TO_HOST,
- cann_ctx->stream()));
- } else {
- void *transform_buffer = malloc(size);
- ACL_CHECK(aclrtMemcpyAsync(
- transform_buffer, size, (char *)tensor->data + offset, size,
- ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
- ggml_backend_cann_transform_back(tensor, transform_buffer, data);
- free(transform_buffer);
- }
-}
-
-/**
- * @brief Asynchronously copies tensor data between CANN backends.
- *
- * This function copies tensor data asynchronously between two CANN backends. It
- * checks if both tensors reside in CANN buffers and whether the devices support
- * peer-to-peer access for direct copying. If not, it returns false.
- *
- * @param backend_src Pointer to the source CANN backend structure.
- * @param backend_dst Pointer to the destination CANN backend structure.
- * @param src Pointer to the source tensor to copy data from.
- * @param dst Pointer to the destination tensor to copy data to.
- * @return true if the copy operation succeeds, false otherwise.
- */
-static bool ggml_backend_cann_cpy_tensor_async(
- ggml_backend_t backend_src, ggml_backend_t backend_dst,
- const ggml_tensor* src, ggml_tensor* dst) {
- GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
- ggml_backend_is_cann(backend_dst));
-
- if (!ggml_backend_buffer_is_cann(src->buffer) ||
- !ggml_backend_buffer_is_cann(dst->buffer)) {
- return false;
- }
-
- ggml_backend_buffer_t buf_src =
- src->view_src ? src->view_src->buffer : src->buffer;
- ggml_backend_buffer_t buf_dst =
- dst->view_src ? dst->view_src->buffer : dst->buffer;
-
- ggml_backend_cann_context* cann_ctx_src =
- (ggml_backend_cann_context*)backend_src->context;
- ggml_backend_cann_context* cann_ctx_dst =
- (ggml_backend_cann_context*)backend_dst->context;
-
- size_t copy_size = ggml_nbytes(dst);
- if (backend_src != backend_dst) {
- ggml_backend_cann_buffer_context* buf_ctx_src =
- (ggml_backend_cann_buffer_context*)buf_src->context;
- ggml_backend_cann_buffer_context* buf_ctx_dst =
- (ggml_backend_cann_buffer_context*)buf_dst->context;
-
- GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
- GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
-
- int32_t canAccessPeer = 0;
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
- cann_ctx_dst->device));
- if (!canAccessPeer) {
- return false;
- }
-
-        // peer access must be enabled in both directions for aclrtMemcpyAsync between devices.
- ggml_cann_set_device(cann_ctx_dst->device);
- ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
- ggml_cann_set_device(cann_ctx_src->device);
- ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
-
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
- ACL_MEMCPY_DEVICE_TO_DEVICE,
- cann_ctx_src->stream()));
-
-        // TODO: workaround because events don't work here yet.
- aclrtSynchronizeStream(cann_ctx_src->stream());
- } else {
- // src and dst are on the same backend
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
- ACL_MEMCPY_DEVICE_TO_DEVICE,
- cann_ctx_dst->stream()));
- }
-
- return true;
-}
-
-/**
- * @brief Synchronizes a CANN backend.
- *
- * This function synchronizes the specified CANN backend by waiting for all
- * operations in its associated stream to complete.
- *
- * @param backend Pointer to the CANN backend structure to synchronize.
- */
-static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
-
- ggml_cann_set_device(cann_ctx->device);
-
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-}
-
-/**
- * @brief Computes a computational graph using a CANN backend.
- *
- * This function computes the operations defined in the computational graph
- * using the specified CANN backend.
- *
- * @param backend Pointer to the CANN backend structure to use for computation.
- * @param cgraph Pointer to the computational graph structure containing nodes
- * representing operations to be computed.
- * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
- * completes successfully, otherwise an appropriate error status.
- */
-static enum ggml_status ggml_backend_cann_graph_compute(
- ggml_backend_t backend, ggml_cgraph* cgraph) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
-
- ggml_cann_set_device(cann_ctx->device);
-
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor* node = cgraph->nodes[i];
-
- if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
- continue;
- }
-
- bool ok = ggml_cann_compute_forward(*cann_ctx, node);
-
- if (!ok) {
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
- node->name, ggml_op_name(node->op));
- }
- GGML_ASSERT(ok);
- }
-
- return GGML_STATUS_SUCCESS;
-}
-
-/**
- * @brief Checks if the CANN backend supports a specific operation.
- *
- * This function checks whether the specified operation is supported by the
- * CANN backend.
- *
- * @param dev Pointer to the CANN backend device to check support for the
- * operation.
- * @param op Pointer to the tensor representing the operation to check.
- * @return bool Returns true if the operation is supported by the backend,
- * otherwise false.
- */
-static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
- const ggml_tensor* op) {
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_HARDSIGMOID:
- case GGML_UNARY_OP_HARDSWISH:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_TANH:
- return true;
- default:
- return false;
- }
- case GGML_OP_MUL_MAT: {
- switch (op->src[0]->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_F32:
- case GGML_TYPE_Q8_0:
- // TODO: fix me
- // Current groupsize should not be greater than k-1 in
- // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
- case GGML_TYPE_Q4_0:
- return true;
- default:
- return false;
- }
- }
- case GGML_OP_MUL_MAT_ID:
- return false;
- // embedding
- case GGML_OP_GET_ROWS: {
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q8_0:
- return true;
- default:
- return false;
- }
- } break;
- case GGML_OP_CPY: {
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q4_0:
- return true;
- default:
- return false;
- }
- }
- case GGML_OP_DUP:
- case GGML_OP_REPEAT:
- case GGML_OP_CONCAT:
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_NORM:
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_RMS_NORM:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_CLAMP:
- case GGML_OP_CONT:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_ROPE:
- case GGML_OP_IM2COL:
- case GGML_OP_POOL_2D:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_ARGSORT:
- case GGML_OP_ACC:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
- case GGML_OP_ARANGE:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_LEAKY_RELU:
- return true;
- default:
- return false;
- }
-
- GGML_UNUSED(dev);
-}
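-
-// Usage sketch (assuming the public ggml-backend device API): callers normally
-// reach this check through the generic entry point rather than calling it
-// directly, e.g.:
-//
-//     ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0);
-//     bool supported = ggml_backend_dev_supports_op(dev, some_tensor);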
-
-/**
- * @brief Checks if the backend buffer type is associated with the CANN backend.
- *
- * This function checks whether the provided backend buffer type is associated
- * with the CANN backend based on the comparison of its name retrieval function
- * pointer.
- *
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the buffer type is associated with the CANN
- * backend, otherwise false.
- */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
-}
-
-/**
- * @brief Determines if a tensor operation should be offloaded to the CANN
- * backend.
- *
- * This function checks if a given tensor operation should be offloaded to the
- * CANN backend based on the operation type and the size of the tensor. It
- * returns true if the second dimension (ne[1]) of the tensor is greater than or
- * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
- *
- * @param dev Pointer to the CANN backend device.
- * @param op Pointer to the tensor operation to check.
- * @return bool Returns true if the operation should be offloaded, otherwise
- * false.
- */
-static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
- const ggml_tensor* op) {
- const int min_batch_size = 32;
- GGML_UNUSED(dev);
-
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
-}
-
-/**
- * @brief Records an event on the CANN backend stream.
- *
- * This function records the given event on the ACL runtime stream associated
- * with the backend context.
- *
- * @param backend Pointer to the CANN backend whose stream the event is
- * recorded on.
- * @param event Pointer to the event structure to be recorded.
- */
-static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
- ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
-}
-
-/**
- * @brief Waits for a recorded event to complete on the CANN backend stream.
- *
- * This function makes the given backend wait for the event to complete on its
- * ACL runtime stream.
- *
- * @param backend Pointer to the backend structure.
- * @param event Pointer to the event structure that the backend needs to wait
- * for.
- */
-static void ggml_backend_cann_event_wait(ggml_backend_t backend,
- ggml_backend_event_t event) {
- ggml_backend_cann_context* cann_ctx =
- (ggml_backend_cann_context*)backend->context;
- if (ggml_backend_is_cann(backend)) {
- ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
- (aclrtEvent)event->context));
- } else {
- GGML_ABORT("fatal error");
- }
-}
-
-/**
- * @brief Structure defining the interface for the CANN backend.
- *
- * This structure contains function pointers for various operations
- * supported by the CANN backend, including name retrieval, memory
- * management, tensor operations, synchronization, and event handling.
- */
-static const ggml_backend_i ggml_backend_cann_interface = {
- /* .get_name = */ ggml_backend_cann_name,
- /* .free = */ ggml_backend_cann_free,
- /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
- /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
- /* .synchronize = */ ggml_backend_cann_synchronize,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_cann_graph_compute,
- /* .event_record = */ ggml_backend_cann_event_record,
- /* .event_wait = */ ggml_backend_cann_event_wait,
-};
-
-/**
- * @brief Return the hardcoded GUID for the CANN backend.
- *
- * This function returns a static GUID which uniquely identifies the CANN
- * backend.
- *
- * @return A pointer to the static GUID.
- */
-static ggml_guid_t ggml_backend_cann_guid() {
- static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
- 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
- return &guid;
-}
-
-// backend device
-struct ggml_backend_cann_device_context {
- int device;
- std::string name;
- std::string description;
-};
-
-static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
- return ctx->name.c_str();
-}
-
-static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
- return ctx->description.c_str();
-}
-
-static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
- ggml_backend_cann_get_device_memory(ctx->device, free, total);
-}
-
-static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
- props->name = ggml_backend_cann_device_get_name(dev);
- props->description = ggml_backend_cann_device_get_description(dev);
- props->type = ggml_backend_cann_device_get_type(dev);
- ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
- bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
-
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ host_buffer,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ true,
- };
-}
-
-static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
- GGML_UNUSED(params);
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
- return ggml_backend_cann_init(ctx->device);
-}
-
-/**
- * @brief Checks if the CANN backend supports a specific backend buffer type.
- *
- * This function determines whether the CANN backend supports the given backend
- * buffer type by comparing the device context of the backend and buffer type.
- * It returns true if the device of the backend context matches the device of
- * the buffer type context.
- *
- * @param backend Pointer to the CANN backend.
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the CANN backend supports the buffer type,
- * otherwise false.
- */
-static bool ggml_backend_cann_supports_buft(
- ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- if (ggml_backend_buft_is_cann(buft)) {
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
- ggml_backend_cann_buffer_type_context * buft_ctx =
- (ggml_backend_cann_buffer_type_context *)buft->context;
- return buft_ctx->device == dev_ctx->device;
- }
- return false;
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
- return ggml_backend_cann_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return ggml_backend_cann_host_buffer_type();
-}
-
-/**
- * @brief Creates a new event for the CANN backend device.
- *
- * This function initializes a new event for the CANN backend by setting the
- * device and creating an ACL runtime event. The created event is then wrapped
- * in a ggml_backend_event structure and returned.
- *
- * @param dev Pointer to the CANN backend device.
- * @return ggml_backend_event_t Returns a pointer to the new event structure.
- */
-static ggml_backend_event_t ggml_backend_cann_device_event_new(
- ggml_backend_dev_t dev) {
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
-
- ggml_cann_set_device(dev_ctx->device);
-
- aclrtEvent event;
- ACL_CHECK(aclrtCreateEvent(&event));
-
- return new ggml_backend_event{
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
- /* .context = */ event,
- };
-}
-
-/**
- * @brief Frees a CANN backend event.
- *
- * This function destroys the ACL runtime event associated with the given CANN
- * backend event and then deletes the event structure itself.
- *
- * @param event Pointer to the event structure to be freed.
- */
-static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
- ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
-
- delete event;
- GGML_UNUSED(dev);
-}
-
-/**
- * @brief Synchronizes the given event on the CANN backend.
- *
- * This function waits for the specified event to complete on the ACL runtime.
- *
- * @param event Pointer to the event structure to be synchronized.
- */
-static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
- ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
-
- GGML_UNUSED(dev);
-}
-
-static const ggml_backend_device_i ggml_backend_cann_device_interface = {
- /* .get_name = */ ggml_backend_cann_device_get_name,
- /* .get_description = */ ggml_backend_cann_device_get_description,
- /* .get_memory = */ ggml_backend_cann_device_get_memory,
- /* .get_type = */ ggml_backend_cann_device_get_type,
- /* .get_props = */ ggml_backend_cann_device_get_props,
- /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
- /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
- /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
- /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
- /* .supports_op = */ ggml_backend_cann_supports_op,
- /* .supports_buft = */ ggml_backend_cann_supports_buft,
- /* .offload_op = */ ggml_backend_cann_offload_op,
- /* .event_new = */ ggml_backend_cann_device_event_new,
- /* .event_free = */ ggml_backend_cann_device_event_free,
- /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
-};
-
-
-// backend reg
-struct ggml_backend_cann_reg_context {
- std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
- GGML_UNUSED(reg);
- return GGML_CANN_NAME;
-}
-
-static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
- return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
- GGML_ASSERT(index < ctx->devices.size());
- return ctx->devices[index];
-}
-
-static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- GGML_UNUSED(reg);
- GGML_UNUSED(name);
- // reserved for future use
- return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
- /* .get_name = */ ggml_backend_cann_reg_get_name,
- /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
-    /* .get_device = */ ggml_backend_cann_reg_get_device,
- /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
-};
-
-// backend registry, called only once for cann backend
-ggml_backend_reg_t ggml_backend_cann_reg() {
- static ggml_backend_reg reg;
- static bool initialized = false;
-
- {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- if (!initialized) {
- aclInit(nullptr);
- ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
-
- for (int i = 0; i < ggml_cann_info().device_count; i++) {
- ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
- dev_ctx->description = aclrtGetSocName();
- dev_ctx->device = i;
- dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
- ggml_cann_set_device(i);
- ggml_backend_dev_t dev = new ggml_backend_device {
- /* .interface = */ ggml_backend_cann_device_interface,
-                    /* .reg = */ &reg,
- /* .context = */ dev_ctx
- };
- ctx->devices.push_back(dev);
- }
-
- reg = ggml_backend_reg {
- /* .interface = */ ggml_backend_cann_reg_interface,
- /* .context = */ ctx
- };
- }
-
- initialized = true;
- }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_cann_init(int32_t device) {
- aclInit(nullptr);
- if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
- GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
- return nullptr;
- }
-
- ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
- if (ctx == nullptr) {
- GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
- return nullptr;
- }
- ggml_cann_set_device(ctx->device);
- ggml_backend_t cann_backend =
- new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
- /* .interface = */ ggml_backend_cann_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
- /* .context = */ ctx};
-
- return cann_backend;
-}
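-
-// Usage sketch (assuming the public ggml-backend API): create a CANN backend on
-// device 0, compute a graph, then release it.
-//
-//     ggml_backend_t backend = ggml_backend_cann_init(0);
-//     if (backend != NULL) {
-//         // ... allocate tensors on the backend, build a graph ...
-//         // ggml_backend_graph_compute(backend, graph);
-//         ggml_backend_free(backend);
-//     }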
-
-bool ggml_backend_is_cann(ggml_backend_t backend) {
- return backend != NULL &&
- ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
-}
-
-int32_t ggml_backend_cann_get_device_count() {
- return ggml_cann_info().device_count;
-}
-
-void ggml_backend_cann_get_device_description(
- int32_t device, char* description, size_t description_size) {
- ggml_cann_set_device(device);
- const char* soc_name = aclrtGetSocName();
- snprintf(description, description_size, "%s", soc_name);
-}
-
-void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
- size_t* total) {
- ggml_cann_set_device(device);
- ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
-}
+++ /dev/null
-#pragma once
-
-// GGML CPU internal header
-
-#include "ggml.h"
-#include "ggml-impl.h"
-#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
-//#include <stddef.h>
-#include <stdbool.h>
-#include <string.h> // memcpy
-#include <math.h> // fabsf
-
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#if defined(_MSC_VER)
-
-#define m512bh(p) p
-#define m512i(p) p
-
-#else
-
-#define m512bh(p) (__m512bh)(p)
-#define m512i(p) (__m512i)(p)
-
-#endif
-
-/**
- * Converts brain16 to float32.
- *
- * The bfloat16 floating point format has the following structure:
- *
- * ┌sign
- * │
- * │ ┌exponent
- * │ │
- * │ │ ┌mantissa
- * │ │ │
- * │┌──┴───┐┌─┴───┐
- * 0b0000000000000000 brain16
- *
- * Since bf16 has the same number of exponent bits as a 32bit float,
- * encoding and decoding numbers becomes relatively straightforward.
- *
- * ┌sign
- * │
- * │ ┌exponent
- * │ │
- * │ │ ┌mantissa
- * │ │ │
- * │┌──┴───┐┌─┴───────────────────┐
- * 0b00000000000000000000000000000000 IEEE binary32
- *
- * For comparison, the standard fp16 format has fewer exponent bits.
- *
- * ┌sign
- * │
- * │ ┌exponent
- * │ │
- * │ │ ┌mantissa
- * │ │ │
- * │┌─┴─┐┌─┴──────┐
- * 0b0000000000000000 IEEE binary16
- *
- * @see IEEE 754-2008
- */
-static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
- union {
- float f;
- uint32_t i;
- } u;
- u.i = (uint32_t)h.bits << 16;
- return u.f;
-}
-
-/**
- * Converts float32 to brain16.
- *
- * This is binary identical with Google Brain float conversion.
- * Floats shall round to nearest even, and NANs shall be quiet.
- * Subnormals aren't flushed to zero, except perhaps when used.
- * This code should vectorize nicely if using modern compilers.
- */
-static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
- ggml_bf16_t h;
- union {
- float f;
- uint32_t i;
- } u;
- u.f = s;
- if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
- h.bits = (u.i >> 16) | 64; /* force to quiet */
- return h;
- }
- h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
- return h;
-}
-
-#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
-#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
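-
-// Round-trip sketch: bf16 keeps the full fp32 exponent range but only 7
-// mantissa bits, so a conversion round trip preserves roughly 2-3 significant
-// decimal digits, e.g.:
-//
-//     float x = 3.14159265f;
-//     ggml_bf16_t b = GGML_FP32_TO_BF16(x);
-//     float y = GGML_BF16_TO_FP32(b);   // 3.140625f, the nearest bf16 value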
-
-// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
-#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
-#ifndef __FMA__
-#define __FMA__
-#endif
-#ifndef __F16C__
-#define __F16C__
-#endif
-#endif
-
-// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
-#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
-#ifndef __SSE3__
-#define __SSE3__
-#endif
-#ifndef __SSSE3__
-#define __SSSE3__
-#endif
-#endif
-
-#if defined(__ARM_FEATURE_SVE)
-#include <arm_sve.h>
-#include <sys/prctl.h>
-#endif
-
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
-#if defined(__ARM_NEON)
-
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
-#ifdef _MSC_VER
-
-typedef uint16_t ggml_fp16_internal_t;
-
-#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
-#else
-
-typedef __fp16 ggml_fp16_internal_t;
-
-#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
-#endif // _MSC_VER
-
-#if !defined(__aarch64__)
-
-// 32-bit ARM compatibility
-
-// vaddlvq_s16
-// vpaddq_s16
-// vpaddq_s32
-// vaddvq_s32
-// vaddvq_f32
-// vmaxvq_f32
-// vcvtnq_s32_f32
-// vzip1_u8
-// vzip2_u8
-
-inline static int32_t vaddlvq_s16(int16x8_t v) {
- int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
- return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
- return vcombine_s16(a0, b0);
-}
-
-inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
- int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
- int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
- return vcombine_s32(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-
-inline static float vaddvq_f32(float32x4_t v) {
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
-}
-
-inline static float vmaxvq_f32(float32x4_t v) {
- return
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
-inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
- int32x4_t res;
-
- res[0] = roundf(vgetq_lane_f32(v, 0));
- res[1] = roundf(vgetq_lane_f32(v, 1));
- res[2] = roundf(vgetq_lane_f32(v, 2));
- res[3] = roundf(vgetq_lane_f32(v, 3));
-
- return res;
-}
-
-inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
- uint8x8_t res;
-
- res[0] = a[0]; res[1] = b[0];
- res[2] = a[1]; res[3] = b[1];
- res[4] = a[2]; res[5] = b[2];
- res[6] = a[3]; res[7] = b[3];
-
- return res;
-}
-
-inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
- uint8x8_t res;
-
- res[0] = a[4]; res[1] = b[4];
- res[2] = a[5]; res[3] = b[5];
- res[4] = a[6]; res[5] = b[6];
- res[6] = a[7]; res[7] = b[7];
-
- return res;
-}
-
-// vld1q_s16_x2
-// vld1q_u8_x2
-// vld1q_u8_x4
-// vld1q_s8_x2
-// vld1q_s8_x4
-// TODO: double-check these work correctly
-
-typedef struct ggml_int16x8x2_t {
- int16x8_t val[2];
-} ggml_int16x8x2_t;
-
-inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
- ggml_int16x8x2_t res;
-
- res.val[0] = vld1q_s16(ptr + 0);
- res.val[1] = vld1q_s16(ptr + 8);
-
- return res;
-}
-
-typedef struct ggml_uint8x16x2_t {
- uint8x16_t val[2];
-} ggml_uint8x16x2_t;
-
-inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
- ggml_uint8x16x2_t res;
-
- res.val[0] = vld1q_u8(ptr + 0);
- res.val[1] = vld1q_u8(ptr + 16);
-
- return res;
-}
-
-typedef struct ggml_uint8x16x4_t {
- uint8x16_t val[4];
-} ggml_uint8x16x4_t;
-
-inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
- ggml_uint8x16x4_t res;
-
- res.val[0] = vld1q_u8(ptr + 0);
- res.val[1] = vld1q_u8(ptr + 16);
- res.val[2] = vld1q_u8(ptr + 32);
- res.val[3] = vld1q_u8(ptr + 48);
-
- return res;
-}
-
-typedef struct ggml_int8x16x2_t {
- int8x16_t val[2];
-} ggml_int8x16x2_t;
-
-inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
- ggml_int8x16x2_t res;
-
- res.val[0] = vld1q_s8(ptr + 0);
- res.val[1] = vld1q_s8(ptr + 16);
-
- return res;
-}
-
-typedef struct ggml_int8x16x4_t {
- int8x16_t val[4];
-} ggml_int8x16x4_t;
-
-inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
- ggml_int8x16x4_t res;
-
- res.val[0] = vld1q_s8(ptr + 0);
- res.val[1] = vld1q_s8(ptr + 16);
- res.val[2] = vld1q_s8(ptr + 32);
- res.val[3] = vld1q_s8(ptr + 48);
-
- return res;
-}
-
-// NOTE: not tested
-inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
- int8x16_t res;
-
- res[ 0] = a[b[ 0]];
- res[ 1] = a[b[ 1]];
- res[ 2] = a[b[ 2]];
- res[ 3] = a[b[ 3]];
- res[ 4] = a[b[ 4]];
- res[ 5] = a[b[ 5]];
- res[ 6] = a[b[ 6]];
- res[ 7] = a[b[ 7]];
- res[ 8] = a[b[ 8]];
- res[ 9] = a[b[ 9]];
- res[10] = a[b[10]];
- res[11] = a[b[11]];
- res[12] = a[b[12]];
- res[13] = a[b[13]];
- res[14] = a[b[14]];
- res[15] = a[b[15]];
-
- return res;
-}
-
-// NOTE: not tested
-inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
- uint8x16_t res;
-
- res[ 0] = a[b[ 0]];
- res[ 1] = a[b[ 1]];
- res[ 2] = a[b[ 2]];
- res[ 3] = a[b[ 3]];
- res[ 4] = a[b[ 4]];
- res[ 5] = a[b[ 5]];
- res[ 6] = a[b[ 6]];
- res[ 7] = a[b[ 7]];
- res[ 8] = a[b[ 8]];
- res[ 9] = a[b[ 9]];
- res[10] = a[b[10]];
- res[11] = a[b[11]];
- res[12] = a[b[12]];
- res[13] = a[b[13]];
- res[14] = a[b[14]];
- res[15] = a[b[15]];
-
- return res;
-}
-
-#else
-
-#define ggml_int16x8x2_t int16x8x2_t
-#define ggml_uint8x16x2_t uint8x16x2_t
-#define ggml_uint8x16x4_t uint8x16x4_t
-#define ggml_int8x16x2_t int8x16x2_t
-#define ggml_int8x16x4_t int8x16x4_t
-
-#define ggml_vld1q_s16_x2 vld1q_s16_x2
-#define ggml_vld1q_u8_x2 vld1q_u8_x2
-#define ggml_vld1q_u8_x4 vld1q_u8_x4
-#define ggml_vld1q_s8_x2 vld1q_s8_x2
-#define ggml_vld1q_s8_x4 vld1q_s8_x4
-#define ggml_vqtbl1q_s8 vqtbl1q_s8
-#define ggml_vqtbl1q_u8 vqtbl1q_u8
-
-#endif // !defined(__aarch64__)
-
-#if !defined(__ARM_FEATURE_DOTPROD)
-
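-// fallback for CPUs without the dot-product extension: widen the int8 products
-// to int16 with vmull_s8, then pairwise-accumulate into int32 lanes. The lane
-// grouping differs from a real vdotq_s32, but the sum across all four lanes is
-// identical, which is all the callers rely on.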
-inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
- const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
- const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-
- return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
-}
-
-#else
-
-#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
-
-#endif // !defined(__ARM_FEATURE_DOTPROD)
-
-#endif // defined(__ARM_NEON)
-
-#if defined(__ARM_NEON) && !defined(_MSC_VER)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
- ggml_fp16_internal_t tmp;
- memcpy(&tmp, &h, sizeof(ggml_fp16_t));
- return (float)tmp;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
- ggml_fp16_t res;
- ggml_fp16_internal_t tmp = f;
- memcpy(&res, &tmp, sizeof(ggml_fp16_t));
- return res;
-}
-
-#else
-
-#ifdef __wasm_simd128__
-#include <wasm_simd128.h>
-#else
-#ifdef __POWER9_VECTOR__
-#include <altivec.h>
-#undef bool
-#define bool _Bool
-#else
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <intrin.h>
-#else
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
-#if !defined(__riscv)
-#include <immintrin.h>
-#endif
-#endif
-#endif
-#endif
-#endif
-
-#ifdef __riscv_v_intrinsic
-#include <riscv_vector.h>
-#endif
-
-#if defined(__loongarch64)
-#if defined(__loongarch_asx)
-#include <lasxintrin.h>
-#endif
-#if defined(__loongarch_sx)
-#include <lsxintrin.h>
-#endif
-#endif
-
-#if defined(__loongarch_asx)
-
-typedef union {
- int32_t i;
- float f;
-} ft_union;
-
-/* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
- ft_union fi_tmpval = {.f = val};
- return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
-}
-
-static __m256 __lasx_xvreplfr2vr_s(float val) {
- ft_union fi_tmpval = {.f = val};
- return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
-}
-#endif
-
-#ifdef __F16C__
-
-#ifdef _MSC_VER
-#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-#else
-#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-#endif
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-/* the inline asm below is about 12% faster than the lookup method */
-#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
- register float f;
- register double d;
- __asm__(
- "mtfprd %0,%2\n"
- "xscvhpdp %0,%0\n"
- "frsp %1,%0\n" :
- /* temp */ "=d"(d),
- /* out */ "=f"(f):
- /* in */ "r"(h));
- return f;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
- register double d;
- register ggml_fp16_t r;
- __asm__( /* xscvdphp can work on double or single precision */
- "xscvdphp %0,%2\n"
- "mffprd %1,%0\n" :
- /* temp */ "=d"(d),
- /* out */ "=r"(r):
- /* in */ "f"(f));
- return r;
-}
-
-#else
-
-// FP16 <-> FP32
-// ref: https://github.com/Maratyszcza/FP16
-
-static inline float fp32_from_bits(uint32_t w) {
- union {
- uint32_t as_bits;
- float as_value;
- } fp32;
- fp32.as_bits = w;
- return fp32.as_value;
-}
-
-static inline uint32_t fp32_to_bits(float f) {
- union {
- float as_value;
- uint32_t as_bits;
- } fp32;
- fp32.as_value = f;
- return fp32.as_bits;
-}
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
- const uint32_t w = (uint32_t) h << 16;
- const uint32_t sign = w & UINT32_C(0x80000000);
- const uint32_t two_w = w + w;
-
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
- const float exp_scale = 0x1.0p-112f;
-#else
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
- const uint32_t magic_mask = UINT32_C(126) << 23;
- const float magic_bias = 0.5f;
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
- const uint32_t result = sign |
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
- return fp32_from_bits(result);
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
- const float scale_to_inf = 0x1.0p+112f;
- const float scale_to_zero = 0x1.0p-110f;
-#else
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
- const uint32_t w = fp32_to_bits(f);
- const uint32_t shl1_w = w + w;
- const uint32_t sign = w & UINT32_C(0x80000000);
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
- if (bias < UINT32_C(0x71000000)) {
- bias = UINT32_C(0x71000000);
- }
-
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
- const uint32_t bits = fp32_to_bits(base);
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
- const uint32_t nonsign = exp_bits + mantissa_bits;
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
-}
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#endif // __F16C__
-
-#endif // defined(__ARM_NEON) && !defined(_MSC_VER)
-
-#ifdef __ARM_FEATURE_SVE
-#include <arm_sve.h>
-#endif // __ARM_FEATURE_SVE
-
-// precomputed f32 table for f16 (256 KB)
-// defined in ggml.c, initialized in ggml_init()
-extern float ggml_table_f32_f16[1 << 16];
-
-// On ARM NEON, it's quicker to convert fp16 <-> fp32 directly in hardware than to call into ggml_lookup_fp16_to_fp32,
-// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(GGML_FP16_TO_FP32)
-inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
- uint16_t s;
- memcpy(&s, &f, sizeof(uint16_t));
- return ggml_table_f32_f16[s];
-}
-
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#endif
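-
-// Sketch of how ggml_init() is expected to fill the table (the authoritative
-// code lives in ggml.c):
-//
-//     for (uint32_t i = 0; i < (1 << 16); ++i) {
-//         uint16_t u = (uint16_t) i;
-//         ggml_fp16_t h;
-//         memcpy(&h, &u, sizeof(h));
-//         ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(h);
-//     }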
-
-#if !defined(GGML_FP32_TO_FP16)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-#endif
-
-#ifdef __cplusplus
-}
-#endif
+++ /dev/null
-#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
-#define _USE_MATH_DEFINES // For M_PI on MSVC
-
-#include "ggml-aarch64.h"
-#include "ggml-backend-impl.h"
-#include "ggml-backend.h"
-#include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-#include "ggml.h"
-
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <malloc.h> // using malloc.h with MSC/MINGW
-#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
-#include <alloca.h>
-#endif
-
-#include <assert.h>
-#include <errno.h>
-#include <time.h>
-#include <math.h>
-#include <stdlib.h>
-#include <string.h>
-#include <stdint.h>
-#include <inttypes.h>
-#include <stdio.h>
-#include <float.h>
-#include <limits.h>
-#include <stdarg.h>
-#include <signal.h>
-#if defined(__gnu_linux__)
-#include <syscall.h>
-#endif
-
-#ifdef GGML_USE_OPENMP
-#include <omp.h>
-#endif
-
-#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
-#undef GGML_USE_LLAMAFILE
-#endif
-
-#ifdef GGML_USE_LLAMAFILE
-#include <llamafile/sgemm.h>
-#endif
-
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-
-// disable POSIX deprecation warnings
-// these functions are never going away, anyway
-#pragma warning(disable: 4996)
-
-// unreachable code because of multiple instances of code after GGML_ABORT
-#pragma warning(disable: 4702)
-#endif
-
-// Note: once threading is moved into a separate C++ file, we will use
-// std::hardware_destructive_interference_size instead of hardcoding the cache line size here,
-// and we will use C++ attribute syntax for the alignment.
-#define GGML_CACHE_LINE 64
-
-#if defined(__clang__) || defined(__GNUC__)
-#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
-#endif
-
-#if defined(__has_feature)
-#if __has_feature(thread_sanitizer)
-#define GGML_TSAN_ENABLED 1
-#endif
-#else // __has_feature
-#if defined(__SANITIZE_THREAD__)
-#define GGML_TSAN_ENABLED 1
-#endif
-#endif // __has_feature
-
-#define UNUSED GGML_UNUSED
-#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
-
-#if defined(GGML_USE_ACCELERATE)
-#include <Accelerate/Accelerate.h>
-#endif
-
-// floating point type used to accumulate sums
-typedef double ggml_float;
-
-#define GGML_GELU_FP16
-#define GGML_GELU_QUICK_FP16
-
-#define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL 2
-#define GGML_VEC_MAD_UNROLL 32
-
-//
-// global data
-//
-
-// precomputed gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
-
-// precomputed quick gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
-
-// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
-float ggml_table_f32_f16[1 << 16];
-
-#if defined(__ARM_ARCH)
-struct ggml_arm_arch_features_type {
- int has_neon;
- int has_i8mm;
- int has_sve;
- int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, 0};
-#endif
-
-
-#if defined(_WIN32)
-
-#define WIN32_LEAN_AND_MEAN
-#ifndef NOMINMAX
- #define NOMINMAX
-#endif
-#include <windows.h>
-
-
-#if !defined(__clang__)
-#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
-
-typedef volatile LONG atomic_int;
-typedef atomic_int atomic_bool;
-typedef atomic_int atomic_flag;
-
-#define ATOMIC_FLAG_INIT 0
-
-typedef enum {
- memory_order_relaxed,
- memory_order_consume,
- memory_order_acquire,
- memory_order_release,
- memory_order_acq_rel,
- memory_order_seq_cst
-} memory_order;
-
-static void atomic_store(atomic_int * ptr, LONG val) {
- InterlockedExchange(ptr, val);
-}
-static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
- // TODO: add support for explicit memory order
- InterlockedExchange(ptr, val);
-}
-static LONG atomic_load(atomic_int * ptr) {
- return InterlockedCompareExchange(ptr, 0, 0);
-}
-static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
- // TODO: add support for explicit memory order
- return InterlockedCompareExchange(ptr, 0, 0);
-}
-static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
- return InterlockedExchangeAdd(ptr, inc);
-}
-static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
- // TODO: add support for explicit memory order
- return InterlockedExchangeAdd(ptr, inc);
-}
-static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
- return InterlockedExchange(ptr, 1);
-}
-static void atomic_flag_clear(atomic_flag * ptr) {
- InterlockedExchange(ptr, 0);
-}
-static void atomic_thread_fence(memory_order mo) {
- MemoryBarrier();
-}
-#else // clang
-#include <stdatomic.h>
-#endif
-
-typedef HANDLE pthread_t;
-
-typedef DWORD thread_ret_t;
-static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
- (void) unused;
- HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
- if (handle == NULL)
- {
- return EAGAIN;
- }
-
- *out = handle;
- return 0;
-}
-
-static int pthread_join(pthread_t thread, void * unused) {
- (void) unused;
- int ret = (int) WaitForSingleObject(thread, INFINITE);
- CloseHandle(thread);
- return ret;
-}
-
-static int sched_yield (void) {
- Sleep (0);
- return 0;
-}
-#else
-
-#include <pthread.h>
-#include <stdatomic.h>
-#include <sched.h>
-#if defined(__FreeBSD__)
-#include <pthread_np.h>
-#endif
-
-typedef void * thread_ret_t;
-
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <unistd.h>
-
-#endif
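-
-// Illustrative sketch (exposition only, not used by ggml): the definitions above give a single
-// spelling for atomics, threads and sched_yield() on both MSVC and POSIX builds, so portable
-// synchronization helpers such as a spin lock can be written once:
-static inline void ggml_spin_lock_sketch(atomic_flag * lock) {
- while (atomic_flag_test_and_set(lock)) {
- sched_yield(); // back off while another thread holds the lock
- }
-}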
-
-typedef pthread_t ggml_thread_t;
-
-#ifdef GGML_USE_CPU_HBM
-#include <hbwmalloc.h>
-#endif
-
-#if defined(__APPLE__)
-#include <unistd.h>
-#include <mach/mach.h>
-#include <TargetConditionals.h>
-#endif
-
-//
-// cache line
-//
-
-#if defined(__cpp_lib_hardware_interference_size)
-#define CACHE_LINE_SIZE hardware_destructive_interference_size
-#else
-#if defined(__POWER9_VECTOR__)
-#define CACHE_LINE_SIZE 128
-#else
-#define CACHE_LINE_SIZE 64
-#endif
-#endif
-
-static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
-
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
-
-static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
- [GGML_TYPE_F32] = {
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
- .vec_dot_type = GGML_TYPE_F32,
- .nrows = 1,
- },
- [GGML_TYPE_F16] = {
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
- .vec_dot_type = GGML_TYPE_F16,
- .nrows = 1,
- },
- [GGML_TYPE_Q4_0] = {
- .vec_dot = ggml_vec_dot_q4_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
- .nrows = 2,
-#else
- .nrows = 1,
-#endif
- },
- [GGML_TYPE_Q4_1] = {
- .vec_dot = ggml_vec_dot_q4_1_q8_1,
- .vec_dot_type = GGML_TYPE_Q8_1,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
- .nrows = 2,
-#else
- .nrows = 1,
-#endif
- },
- [4] = { // GGML_TYPE_Q4_2
- .vec_dot = NULL,
- .vec_dot_type = GGML_TYPE_COUNT,
- .nrows = 1,
- },
- [5] = { // GGML_TYPE_Q4_3
- .vec_dot = NULL,
- .vec_dot_type = GGML_TYPE_COUNT,
- .nrows = 1,
- },
- [GGML_TYPE_Q5_0] = {
- .vec_dot = ggml_vec_dot_q5_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
- .nrows = 1,
- },
- [GGML_TYPE_Q5_1] = {
- .vec_dot = ggml_vec_dot_q5_1_q8_1,
- .vec_dot_type = GGML_TYPE_Q8_1,
- .nrows = 1,
- },
- [GGML_TYPE_Q8_0] = {
- .from_float_to_mat = quantize_mat_q8_0,
- .vec_dot = ggml_vec_dot_q8_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
- .nrows = 2,
-#else
- .nrows = 1,
-#endif
- },
- [GGML_TYPE_Q8_1] = {
- .vec_dot_type = GGML_TYPE_Q8_1,
- .nrows = 1,
- },
- [GGML_TYPE_Q2_K] = {
- .vec_dot = ggml_vec_dot_q2_K_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_Q3_K] = {
- .vec_dot = ggml_vec_dot_q3_K_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_Q4_K] = {
- .vec_dot = ggml_vec_dot_q4_K_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_Q5_K] = {
- .vec_dot = ggml_vec_dot_q5_K_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_Q6_K] = {
- .vec_dot = ggml_vec_dot_q6_K_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ2_XXS] = {
- .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ2_XS] = {
- .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ3_XXS] = {
- .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ3_S] = {
- .vec_dot = ggml_vec_dot_iq3_s_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ2_S] = {
- .vec_dot = ggml_vec_dot_iq2_s_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ1_S] = {
- .vec_dot = ggml_vec_dot_iq1_s_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ1_M] = {
- .vec_dot = ggml_vec_dot_iq1_m_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_IQ4_NL] = {
- .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
- .nrows = 1,
- },
- [GGML_TYPE_IQ4_XS] = {
- .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_BF16] = {
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
- .vec_dot_type = GGML_TYPE_BF16,
- .nrows = 1,
- },
- [GGML_TYPE_Q4_0_4_4] = {
- .vec_dot = NULL,
- .vec_dot_type = GGML_TYPE_Q8_0,
- .nrows = 1,
- .ncols = 4,
- .gemv = ggml_gemv_q4_0_4x4_q8_0,
- .gemm = ggml_gemm_q4_0_4x4_q8_0,
- },
- [GGML_TYPE_Q4_0_4_8] = {
- .vec_dot = NULL,
- .vec_dot_type = GGML_TYPE_Q8_0,
- .nrows = 1,
- .ncols = 4,
- .gemv = ggml_gemv_q4_0_4x8_q8_0,
- .gemm = ggml_gemm_q4_0_4x8_q8_0,
- },
- [GGML_TYPE_Q4_0_8_8] = {
- .vec_dot = NULL,
- .vec_dot_type = GGML_TYPE_Q8_0,
- .nrows = 1,
- .ncols = 8,
- .gemv = ggml_gemv_q4_0_8x8_q8_0,
- .gemm = ggml_gemm_q4_0_8x8_q8_0,
- },
- [GGML_TYPE_TQ1_0] = {
- .vec_dot = ggml_vec_dot_tq1_0_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
- [GGML_TYPE_TQ2_0] = {
- .vec_dot = ggml_vec_dot_tq2_0_q8_K,
- .vec_dot_type = GGML_TYPE_Q8_K,
- .nrows = 1,
- },
-};
-
-const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
- return &type_traits_cpu[type];
-}
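-
-// Illustrative sketch (exposition only; the real callers are the mat mul routines further down):
-// an op consults the table above to pick the dot-product kernel for a weight type and to learn
-// which format the second operand must be quantized to before calling it.
-static inline void ggml_vec_dot_dispatch_sketch(enum ggml_type type, int n, float * s, const void * x, const void * y) {
- const struct ggml_type_traits_cpu * tt = ggml_get_type_traits_cpu(type);
- // tt->vec_dot_type is the type `y` must be in (e.g. GGML_TYPE_Q8_0 for GGML_TYPE_Q4_0),
- // tt->nrows is how many rows the kernel handles per call (1 or 2 in the table above).
- if (tt->vec_dot != NULL) {
- tt->vec_dot(n, s, 0, x, 0, y, 0, 1);
- }
-}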
-
-//
-// simd mappings
-//
-
-// we define a common set of C macros which map to specific intrinsics based on the current architecture
-// we then implement the fundamental computation operations below using only these macros
-// adding support for new architectures requires defining the corresponding SIMD macros
-//
-// GGML_F32_STEP / GGML_F16_STEP
-// number of elements to process in a single step
-//
-// GGML_F32_EPR / GGML_F16_EPR
-// number of elements to fit in a single register
-//
-
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
-
-#define GGML_SIMD
-
-// F32 NEON
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR 4
-
-#define GGML_F32x4 float32x4_t
-#define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
-#define GGML_F32x4_SET1(x) vdupq_n_f32(x)
-#define GGML_F32x4_LOAD vld1q_f32
-#define GGML_F32x4_STORE vst1q_f32
-#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
-#define GGML_F32x4_ADD vaddq_f32
-#define GGML_F32x4_MUL vmulq_f32
-#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#define GGML_F32x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
- } \
- (res) = GGML_F32x4_REDUCE_ONE((x)[0]); \
-}
-
-#define GGML_F32_VEC GGML_F32x4
-#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
- #define GGML_F16_STEP 32
- #define GGML_F16_EPR 8
-
- #define GGML_F16x8 float16x8_t
- #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
- #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
- #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x))
- #define GGML_F16x8_STORE vst1q_f16
- #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
- #define GGML_F16x8_ADD vaddq_f16
- #define GGML_F16x8_MUL vmulq_f16
- #define GGML_F16x8_REDUCE(res, x) \
- do { \
- int offset = GGML_F16_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
- } \
- const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
- const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
- (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
- } while (0)
-
- #define GGML_F16_VEC GGML_F16x8
- #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
- #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
- #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
- #define GGML_F16_VEC_FMA GGML_F16x8_FMA
- #define GGML_F16_VEC_ADD GGML_F16x8_ADD
- #define GGML_F16_VEC_MUL GGML_F16x8_MUL
- #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
-#else
- // if FP16 vector arithmetic is not supported, we use FP32 instead
- // and take advantage of the vcvt_ functions to convert to/from FP16
-
- #define GGML_F16_STEP 16
- #define GGML_F16_EPR 4
-
- #define GGML_F32Cx4 float32x4_t
- #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
- #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
- #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
- #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
- #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
- #define GGML_F32Cx4_ADD vaddq_f32
- #define GGML_F32Cx4_MUL vmulq_f32
- #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
-
- #define GGML_F16_VEC GGML_F32Cx4
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
-#endif
-
-#elif defined(__AVX512F__)
-
-#define GGML_SIMD
-
-// F32 AVX512
-
-#define GGML_F32_STEP 64
-#define GGML_F32_EPR 16
-
-#define GGML_F32x16 __m512
-#define GGML_F32x16_ZERO _mm512_setzero_ps()
-#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
-#define GGML_F32x16_LOAD _mm512_loadu_ps
-#define GGML_F32x16_STORE _mm512_storeu_ps
-// _mm512_fmadd_ps is defined in AVX512F so no guard is required
-#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32x16_ADD _mm512_add_ps
-#define GGML_F32x16_MUL _mm512_mul_ps
-#define GGML_F32x16_REDUCE(res, x) \
-do { \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- res = _mm512_reduce_add_ps(x[0]); \
-} while (0)
-
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC GGML_F32x16
-#define GGML_F32_VEC_ZERO GGML_F32x16_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x16_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x16_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x16_STORE
-#define GGML_F32_VEC_FMA GGML_F32x16_FMA
-#define GGML_F32_VEC_ADD GGML_F32x16_ADD
-#define GGML_F32_VEC_MUL GGML_F32x16_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
-
-// F16 AVX512
-
-#define GGML_F16_STEP 64
-#define GGML_F16_EPR 16
-
-// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
-
-#define GGML_F32Cx16 __m512
-#define GGML_F32Cx16_ZERO _mm512_setzero_ps()
-#define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x)
-
-// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
-// so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
-
-#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32Cx16_ADD _mm512_add_ps
-#define GGML_F32Cx16_MUL _mm512_mul_ps
-#define GGML_F32Cx16_REDUCE(res, x) \
-do { \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm512_add_ps(x[i], x[offset+i]); \
- } \
- res = _mm512_reduce_add_ps(x[0]); \
-} while (0)
-
-#define GGML_F16_VEC GGML_F32Cx16
-#define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
-
-#elif defined(__AVX__)
-
-#define GGML_SIMD
-
-// F32 AVX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR 8
-
-#define GGML_F32x8 __m256
-#define GGML_F32x8_ZERO _mm256_setzero_ps()
-#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
-#define GGML_F32x8_LOAD _mm256_loadu_ps
-#define GGML_F32x8_STORE _mm256_storeu_ps
-#if defined(__FMA__)
- #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
-#else
- #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
-#endif
-#define GGML_F32x8_ADD _mm256_add_ps
-#define GGML_F32x8_MUL _mm256_mul_ps
-#define GGML_F32x8_REDUCE(res, x) \
-do { \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
- } \
- const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
- _mm256_extractf128_ps(x[0], 1)); \
- const __m128 t1 = _mm_hadd_ps(t0, t0); \
- res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC GGML_F32x8
-#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 AVX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR 8
-
-// F16 arithmetic is not supported by AVX, so we use F32 instead
-
-#define GGML_F32Cx8 __m256
-#define GGML_F32Cx8_ZERO _mm256_setzero_ps()
-#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
-
-#if defined(__F16C__)
-// the _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
-#else
-static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
- float tmp[8];
-
- for (int i = 0; i < 8; i++) {
- tmp[i] = GGML_FP16_TO_FP32(x[i]);
- }
-
- return _mm256_loadu_ps(tmp);
-}
-static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
- float arr[8];
-
- _mm256_storeu_ps(arr, y);
-
- for (int i = 0; i < 8; i++)
- x[i] = GGML_FP32_TO_FP16(arr[i]);
-}
-#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
-#endif
-
-#define GGML_F32Cx8_FMA GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD _mm256_add_ps
-#define GGML_F32Cx8_MUL _mm256_mul_ps
-#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC GGML_F32Cx8
-#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_SIMD
-
-// F32 POWER9
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR 4
-
-#define GGML_F32x4 vector float
-#define GGML_F32x4_ZERO 0.0f
-#define GGML_F32x4_SET1 vec_splats
-#define GGML_F32x4_LOAD(p) vec_xl(0, p)
-#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
-#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD vec_add
-#define GGML_F32x4_MUL vec_mul
-#define GGML_F32x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vec_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vec_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vec_add(x[i], x[offset+i]); \
- } \
- res = vec_extract(x[0], 0) + \
- vec_extract(x[0], 1) + \
- vec_extract(x[0], 2) + \
- vec_extract(x[0], 3); \
-}
-
-#define GGML_F32_VEC GGML_F32x4
-#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 POWER9
-#define GGML_F16_STEP GGML_F32_STEP
-#define GGML_F16_EPR GGML_F32_EPR
-#define GGML_F16_VEC GGML_F32x4
-#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F16_VEC_FMA GGML_F32x4_FMA
-#define GGML_F16_VEC_ADD GGML_F32x4_ADD
-#define GGML_F16_VEC_MUL GGML_F32x4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
-// Use vec_xl, not vec_ld, in case the load address is not aligned.
-#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
- vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
- vec_extract_fp32_from_shortl(vec_xl(0, p))
-#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
-#define GGML_F16_VEC_STORE(p, r, i) \
- if (i & 0x1) \
- vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
- r[i - GGML_ENDIAN_BYTE(0)]), \
- 0, p - GGML_F16_EPR)
-
-#elif defined(__wasm_simd128__)
-
-#define GGML_SIMD
-
-// F32 WASM
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR 4
-
-#define GGML_F32x4 v128_t
-#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
-#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
-#define GGML_F32x4_LOAD wasm_v128_load
-#define GGML_F32x4_STORE wasm_v128_store
-#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
-#define GGML_F32x4_ADD wasm_f32x4_add
-#define GGML_F32x4_MUL wasm_f32x4_mul
-#define GGML_F32x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- res = wasm_f32x4_extract_lane(x[0], 0) + \
- wasm_f32x4_extract_lane(x[0], 1) + \
- wasm_f32x4_extract_lane(x[0], 2) + \
- wasm_f32x4_extract_lane(x[0], 3); \
-}
-
-#define GGML_F32_VEC GGML_F32x4
-#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 WASM
-
-#define GGML_F16_STEP 16
-#define GGML_F16_EPR 4
-
-inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
- float tmp[4];
-
- tmp[0] = GGML_FP16_TO_FP32(p[0]);
- tmp[1] = GGML_FP16_TO_FP32(p[1]);
- tmp[2] = GGML_FP16_TO_FP32(p[2]);
- tmp[3] = GGML_FP16_TO_FP32(p[3]);
-
- return wasm_v128_load(tmp);
-}
-
-inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
- float tmp[4];
-
- wasm_v128_store(tmp, x);
-
- p[0] = GGML_FP32_TO_FP16(tmp[0]);
- p[1] = GGML_FP32_TO_FP16(tmp[1]);
- p[2] = GGML_FP32_TO_FP16(tmp[2]);
- p[3] = GGML_FP32_TO_FP16(tmp[3]);
-}
-
-#define GGML_F16x4 v128_t
-#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
-#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
-#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
-#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
-#define GGML_F16x4_FMA GGML_F32x4_FMA
-#define GGML_F16x4_ADD wasm_f32x4_add
-#define GGML_F16x4_MUL wasm_f32x4_mul
-#define GGML_F16x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F16_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
- } \
- res = wasm_f32x4_extract_lane(x[0], 0) + \
- wasm_f32x4_extract_lane(x[0], 1) + \
- wasm_f32x4_extract_lane(x[0], 2) + \
- wasm_f32x4_extract_lane(x[0], 3); \
-}
-
-#define GGML_F16_VEC GGML_F16x4
-#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
-#define GGML_F16_VEC_SET1 GGML_F16x4_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F16x4_FMA
-#define GGML_F16_VEC_ADD GGML_F16x4_ADD
-#define GGML_F16_VEC_MUL GGML_F16x4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
-
-#elif defined(__SSE3__)
-
-#define GGML_SIMD
-
-// F32 SSE
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR 4
-
-#define GGML_F32x4 __m128
-#define GGML_F32x4_ZERO _mm_setzero_ps()
-#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
-#define GGML_F32x4_LOAD _mm_loadu_ps
-#define GGML_F32x4_STORE _mm_storeu_ps
-#if defined(__FMA__)
- // TODO: Does this work?
- #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
-#else
- #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
-#endif
-#define GGML_F32x4_ADD _mm_add_ps
-#define GGML_F32x4_MUL _mm_mul_ps
-#define GGML_F32x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
- } \
- const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
- res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
-}
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC GGML_F32x4
-#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 SSE
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR 4
-
-static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
- float tmp[4];
-
- tmp[0] = GGML_FP16_TO_FP32(x[0]);
- tmp[1] = GGML_FP16_TO_FP32(x[1]);
- tmp[2] = GGML_FP16_TO_FP32(x[2]);
- tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
- return _mm_loadu_ps(tmp);
-}
-
-static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
- float arr[4];
-
- _mm_storeu_ps(arr, y);
-
- x[0] = GGML_FP32_TO_FP16(arr[0]);
- x[1] = GGML_FP32_TO_FP16(arr[1]);
- x[2] = GGML_FP32_TO_FP16(arr[2]);
- x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4 __m128
-#define GGML_F32Cx4_ZERO _mm_setzero_ps()
-#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
-#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD _mm_add_ps
-#define GGML_F32Cx4_MUL _mm_mul_ps
-#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC GGML_F32Cx4
-#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
-
-#elif defined(__loongarch_asx)
-
-#define GGML_SIMD
-
-// F32 LASX
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR 8
-
-#define GGML_F32x8 __m256
-#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
-#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
-#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
-#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
-#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
-#define GGML_F32x8_ADD __lasx_xvfadd_s
-#define GGML_F32x8_MUL __lasx_xvfmul_s
-#define GGML_F32x8_REDUCE(res, x) \
-do { \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
- } \
- float *tmp_p = (float *)&x[0]; \
- res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC GGML_F32x8
-#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 LASX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR 8
-
-// F16 arithmetic is not supported by LASX, so we use F32 instead
-
-#define GGML_F32Cx8 __m256
-#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
-#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
-
-static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
- float tmp[8];
-
- for (int i = 0; i < 8; i++) {
- tmp[i] = GGML_FP16_TO_FP32(x[i]);
- }
-
- return (__m256)__lasx_xvld(tmp, 0);
-}
-static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
- float arr[8];
-
- __lasx_xvst(y, arr, 0);
-
- for (int i = 0; i < 8; i++) {
- x[i] = GGML_FP32_TO_FP16(arr[i]);
- }
-}
-#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
-
-#define GGML_F32Cx8_FMA GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD __lasx_xvfadd_s
-#define GGML_F32Cx8_MUL __lasx_xvfmul_s
-#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC GGML_F32Cx8
-#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
-
-#elif defined(__loongarch_sx)
-
-#define GGML_SIMD
-
-// F32 LSX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR 4
-
-#define GGML_F32x4 __m128
-#define GGML_F32x4_ZERO __lsx_vldi(0)
-#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
-#define GGML_F32x4_STORE(x, y) __lsx_vst((y), (x), 0)
-#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
-#define GGML_F32x4_ADD __lsx_vfadd_s
-#define GGML_F32x4_MUL __lsx_vfmul_s
-#define GGML_F32x4_REDUCE(res, x) \
-{ \
- int offset = GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
- } \
- __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
- const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
- tmp = __lsx_vsrli_d((__m128i)t0, 32); \
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
- res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
-}
-
-#define GGML_F32_VEC GGML_F32x4
-#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 LSX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR 4
-
-static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
- float tmp[4];
-
- tmp[0] = GGML_FP16_TO_FP32(x[0]);
- tmp[1] = GGML_FP16_TO_FP32(x[1]);
- tmp[2] = GGML_FP16_TO_FP32(x[2]);
- tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
- return __lsx_vld(tmp, 0);
-}
-
-static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
- float arr[4];
-
- __lsx_vst(y, arr, 0);
-
- x[0] = GGML_FP32_TO_FP16(arr[0]);
- x[1] = GGML_FP32_TO_FP16(arr[1]);
- x[2] = GGML_FP32_TO_FP16(arr[2]);
- x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4 __m128
-#define GGML_F32Cx4_ZERO __lsx_vldi(0)
-#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD __lsx_vfadd_s
-#define GGML_F32Cx4_MUL __lsx_vfmul_s
-#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC GGML_F32Cx4
-#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
-
-#endif
-
-// GGML_F32_ARR / GGML_F16_ARR
-// number of registers to use per step
-#ifdef GGML_SIMD
-#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
-#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
-#endif
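-
-// Worked example (exposition only): with AVX, GGML_F32_STEP = 32 and GGML_F32_EPR = 8, so each
-// step processes GGML_F32_ARR = 4 vector registers. Every architecture branch above keeps STEP
-// an exact multiple of EPR, which the unrolled loops below rely on:
-#ifdef GGML_SIMD
-_Static_assert(GGML_F32_STEP == GGML_F32_ARR*GGML_F32_EPR, "GGML_F32_STEP must be a multiple of GGML_F32_EPR");
-_Static_assert(GGML_F16_STEP == GGML_F16_ARR*GGML_F16_EPR, "GGML_F16_STEP must be a multiple of GGML_F16_EPR");
-#endif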
-
-//
-// Threading defs
-//
-
-#if defined(_WIN32)
-
-typedef CONDITION_VARIABLE ggml_cond_t;
-typedef SRWLOCK ggml_mutex_t;
-
-#define ggml_mutex_init(m) InitializeSRWLock(m)
-#define ggml_mutex_destroy(m)
-#define ggml_mutex_lock(m) AcquireSRWLockExclusive(m)
-#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
-#define ggml_mutex_lock_shared(m) AcquireSRWLockShared(m)
-#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
-
-#define ggml_cond_init(c) InitializeConditionVariable(c)
-#define ggml_cond_destroy(c)
-#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
-#define ggml_cond_broadcast(c) WakeAllConditionVariable(c)
-
-#define ggml_thread_create pthread_create
-#define ggml_thread_join pthread_join
-
-#else
-
-typedef pthread_cond_t ggml_cond_t;
-typedef pthread_mutex_t ggml_mutex_t;
-
-#define ggml_mutex_init(m) pthread_mutex_init(m, NULL)
-#define ggml_mutex_destroy(m) pthread_mutex_destroy(m)
-#define ggml_mutex_lock(m) pthread_mutex_lock(m)
-#define ggml_mutex_unlock(m) pthread_mutex_unlock(m)
-#define ggml_mutex_lock_shared(m) pthread_mutex_lock(m)
-#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
-
-#define ggml_lock_init(x) UNUSED(x)
-#define ggml_lock_destroy(x) UNUSED(x)
-#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
-#define ggml_lock_lock(x) _mm_pause()
-#else
-#define ggml_lock_lock(x) UNUSED(x)
-#endif
-#define ggml_lock_unlock(x) UNUSED(x)
-
-#define GGML_LOCK_INITIALIZER 0
-#define ggml_cond_init(c) pthread_cond_init(c, NULL)
-#define ggml_cond_destroy(c) pthread_cond_destroy(c)
-#define ggml_cond_wait(c, m) pthread_cond_wait(c, m)
-#define ggml_cond_broadcast(c) pthread_cond_broadcast(c)
-
-#define ggml_thread_create pthread_create
-#define ggml_thread_join pthread_join
-
-#endif
-
-// Threadpool def
-struct ggml_threadpool {
- ggml_mutex_t mutex; // mutex for cond.var
- ggml_cond_t cond; // cond.var for waiting for new work
-
- struct ggml_cgraph * cgraph;
- struct ggml_cplan * cplan;
-
- // synchronization primitives
- atomic_int n_graph; // incremented when there is work to be done (i.e. each graph)
- atomic_int GGML_CACHE_ALIGN n_barrier;
- atomic_int GGML_CACHE_ALIGN n_barrier_passed;
- atomic_int current_chunk; // index of the chunk currently being processed during mat mul, shared between all the threads
-
- // these are atomic as an annotation for thread-sanitizer
- atomic_bool stop; // Used for stopping the threadpool altogether
- atomic_bool pause; // Used for pausing the threadpool or individual threads
- atomic_bool abort; // Used for aborting processing of a graph
-
- struct ggml_compute_state * workers; // per thread state
- int n_threads_max; // number of threads in the pool
- atomic_int n_threads_cur; // number of threads used in the current graph
-
- int32_t prio; // Scheduling priority
- uint32_t poll; // Polling level (0 - no polling)
-
- enum ggml_status ec;
-};
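-
-// Illustrative sketch (exposition only; the actual worker loop lives further down in this file):
-// the mutex/cond pair puts idle workers to sleep, while n_graph and stop are read atomically to
-// decide when a worker should wake up again.
-static inline void ggml_threadpool_wait_sketch(struct ggml_threadpool * tp, int last_graph) {
- ggml_mutex_lock(&tp->mutex);
- while (atomic_load_explicit(&tp->n_graph, memory_order_relaxed) == last_graph &&
- !atomic_load_explicit(&tp->stop, memory_order_relaxed)) {
- ggml_cond_wait(&tp->cond, &tp->mutex);
- }
- ggml_mutex_unlock(&tp->mutex);
-}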
-
-// Per-thread state
-struct ggml_compute_state {
-#ifndef GGML_USE_OPENMP
- ggml_thread_t thrd;
- bool cpumask[GGML_MAX_N_THREADS];
- int last_graph;
- bool pending;
-#endif
- struct ggml_threadpool * threadpool;
- int ith;
-};
-
-struct ggml_compute_params {
- // ith = thread index, nth = number of threads
- int ith, nth;
-
- // work buffer for all threads
- size_t wsize;
- void * wdata;
-
- struct ggml_threadpool * threadpool;
-};
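-
-// Illustrative sketch (exposition only): the usual way an op splits work with these params is to
-// divide the destination rows evenly across nth threads and let thread ith process its slice.
-static inline void ggml_rows_for_thread_sketch(const struct ggml_compute_params * params, int nr, int * ir0, int * ir1) {
- const int dr = (nr + params->nth - 1)/params->nth; // rows per thread, rounded up
- *ir0 = dr*params->ith; // first row owned by this thread
- *ir1 = *ir0 + dr < nr ? *ir0 + dr : nr; // one past the last row
-}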
-
-//
-// fundamental operations
-//
-
-inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
-inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
-#if defined(GGML_SIMD)
- float sumf = 0.0f;
- const int np = (n & ~(GGML_F32_STEP - 1));
-
- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-
- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
-
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
- sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
- }
- }
-
- // reduce sum0..sum3 to sum0
- GGML_F32_VEC_REDUCE(sumf, sum);
-
- // leftovers
- for (int i = np; i < n; ++i) {
- sumf += x[i]*y[i];
- }
-#else
- // scalar
- ggml_float sumf = 0.0;
- for (int i = 0; i < n; ++i) {
- sumf += (ggml_float)(x[i]*y[i]);
- }
-#endif
-
- *s = sumf;
-}
-
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
- int i = 0;
- ggml_float sumf = 0;
-
-#if defined(__AVX512BF16__)
- __m512 c1 = _mm512_setzero_ps();
- __m512 c2 = _mm512_setzero_ps();
- for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
- m512bh(_mm512_loadu_si512((y + i))));
- c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
- m512bh(_mm512_loadu_si512((y + i + 32))));
- }
- sumf += (ggml_float)_mm512_reduce_add_ps(c1);
- sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#elif defined(__AVX512F__)
-#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
- __m512 c1 = _mm512_setzero_ps();
- __m512 c2 = _mm512_setzero_ps();
- for (; i + 32 <= n; i += 32) {
- c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
- c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
- }
- sumf += (ggml_float)_mm512_reduce_add_ps(c1);
- sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#undef LOAD
-#elif defined(__AVX2__)
-#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
- __m256 c1 = _mm256_setzero_ps();
- __m256 c2 = _mm256_setzero_ps();
- __m256 c3 = _mm256_setzero_ps();
- __m256 c4 = _mm256_setzero_ps();
- for (; i + 32 <= n; i += 32) {
- c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
- c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
- c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
- c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
- }
- __m128 g;
- c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
- _mm256_add_ps(c2, c4));
- g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
- _mm256_castps256_ps128(c1));
- g = _mm_add_ps(g, _mm_movehl_ps(g, g));
- g = _mm_add_ss(g, _mm_movehdup_ps(g));
- sumf += (ggml_float)_mm_cvtss_f32(g);
-
-#undef LOAD
-#endif
-
- for (; i < n; ++i) {
- sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
- GGML_BF16_TO_FP32(y[i]));
- }
- *s = sumf;
-}
-
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
- assert(nrc == 1);
- UNUSED(nrc);
- UNUSED(bx);
- UNUSED(by);
- UNUSED(bs);
-
- ggml_float sumf = 0.0;
-
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
-
- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
-
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
-
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
- sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
- }
- }
-
- // reduce sum0..sum3 to sum0
- GGML_F16_VEC_REDUCE(sumf, sum);
-
- // leftovers
- for (int i = np; i < n; ++i) {
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
- }
-#else
- for (int i = 0; i < n; ++i) {
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
- }
-#endif
-
- *s = sumf;
-}
-
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
-
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
-
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
- }
-
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
-
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
-
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
-
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
-
- sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
- }
- }
- }
-
- // reduce sum0..sum3 to sum0
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
- }
- }
-#else
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
- }
- }
-#endif
-
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
- s[i] = sumf[i];
- }
-}
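-
-// Usage sketch (exposition only): computing GGML_VEC_DOT_UNROLL (= 2) dot products of consecutive
-// fp16 rows against the same vector y, with xs giving the distance between rows in bytes.
-static inline void ggml_vec_dot_f16_unroll_sketch(int n, ggml_fp16_t * rows, ggml_fp16_t * y, float * s) {
- // rows points to GGML_VEC_DOT_UNROLL contiguous rows of n elements; s receives one dot per row
- const int xs = (int) (n*sizeof(ggml_fp16_t));
- ggml_vec_dot_f16_unroll(n, xs, s, rows, y);
-}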
-
-inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
-
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
-
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
-
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
- }
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] += x[i]*v;
- }
-#else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] += x[i]*v;
- }
-#endif
-}
-
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
-
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
-
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
-
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
- }
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
- }
-#else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
- }
-#endif
-}
-
-// xs and vs are byte strides of x and v
-inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
-
- const float * restrict x[GGML_VEC_MAD_UNROLL];
- const float * restrict v[GGML_VEC_MAD_UNROLL];
-
- for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
- x[i] = (const float *) ((const char *) xv + i*xs);
- v[i] = (const float *) ((const char *) vv + i*vs);
- }
-
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
-
- GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
-
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- vx[k] = GGML_F32_VEC_SET1(v[k][0]);
- }
-
- GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
-
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
- }
-
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
- }
- }
-
- // leftovers
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = np; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
- }
- }
-#else
- // scalar
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = 0; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
- }
- }
-#endif
-}
-
-//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
-inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
-#if defined(GGML_USE_ACCELERATE)
- vDSP_vsmul(y, 1, &v, y, 1, n);
-#elif defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
-
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
- GGML_F32_VEC ay[GGML_F32_ARR];
-
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
-
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
- }
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] *= v;
- }
-#else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] *= v;
- }
-#endif
-}
-
-inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
-
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
- GGML_F16_VEC ay[GGML_F16_ARR];
-
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
-
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
- }
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
- }
-#else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
- }
-#endif
-}
-
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
-inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
-inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
-inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
-inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
-inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
-inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
-inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
-inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
-inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
-inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
-// TODO: optimize performance
-inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
-
-static const float GELU_COEF_A = 0.044715f;
-static const float GELU_QUICK_COEF = -1.702f;
-static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-inline static float ggml_gelu_f32(float x) {
- return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- const uint16_t * i16 = (const uint16_t *) x;
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_table_gelu_f16[i16[i]];
- }
-}
-
-#ifdef GGML_GELU_FP16
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
- uint16_t t;
- for (int i = 0; i < n; ++i) {
- if (x[i] <= -10.0f) {
- y[i] = 0.0f;
- } else if (x[i] >= 10.0f) {
- y[i] = x[i];
- } else {
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
- memcpy(&t, &fp16, sizeof(uint16_t));
- y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
- }
- }
-}
-#else
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_gelu_f32(x[i]);
- }
-}
-#endif
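-
-// Sketch (an assumption for exposition; the actual one-time setup is part of the init code, which
-// is not shown in this excerpt): the precomputed table used above can be filled by converting
-// every fp16 bit pattern to f32, applying ggml_gelu_f32, and storing the result back as fp16.
-static inline void ggml_init_gelu_table_sketch(void) {
- for (int i = 0; i < (1 << 16); ++i) {
- const uint16_t u = (uint16_t) i;
- ggml_fp16_t h;
- memcpy(&h, &u, sizeof(h));
- ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(GGML_FP16_TO_FP32(h)));
- }
-}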
-
-inline static float ggml_gelu_quick_f32(float x) {
- return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
-}
-
-//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-// const uint16_t * i16 = (const uint16_t *) x;
-// for (int i = 0; i < n; ++i) {
-// y[i] = ggml_table_gelu_quick_f16[i16[i]];
-// }
-//}
-
-#ifdef GGML_GELU_QUICK_FP16
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
- uint16_t t;
- for (int i = 0; i < n; ++i) {
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
- memcpy(&t, &fp16, sizeof(uint16_t));
- y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
- }
-}
-#else
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_gelu_quick_f32(x[i]);
- }
-}
-#endif
-
-// Sigmoid Linear Unit (SiLU) function
-inline static float ggml_silu_f32(float x) {
- return x/(1.0f + expf(-x));
-}
-
-#if __FINITE_MATH_ONLY__
-#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
-#endif
-
-#if defined(__ARM_NEON) && defined(__aarch64__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static float32x4_t ggml_v_expf(float32x4_t x) {
- const float32x4_t r = vdupq_n_f32(0x1.8p23f);
- const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
- const float32x4_t n = vsubq_f32(z, r);
- const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
- vdupq_n_f32(0x1.7f7d1cp-20f));
- const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
- const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
- const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
- const float32x4_t u = vmulq_f32(b, b);
- const float32x4_t j = vfmaq_f32(
- vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
- vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
- vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
- if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
- return vfmaq_f32(k, j, k);
- const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
- const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
- const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
- return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
- vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static float32x4_t ggml_v_silu(float32x4_t x) {
- const float32x4_t one = vdupq_n_f32(1.0f);
- const float32x4_t zero = vdupq_n_f32(0.0f);
- const float32x4_t neg_x = vsubq_f32(zero, x);
- const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
- const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
- return vdivq_f32(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX512F__) && defined(__AVX512DQ__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
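-// here _mm512_scalef_ps performs the 2^n scaling directly, so only lanes with |n| > 192
-// need to be patched to 0 or infinity afterwards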
-inline static __m512 ggml_v_expf(__m512 x) {
- const __m512 r = _mm512_set1_ps(0x1.8p23f);
- const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
- const __m512 n = _mm512_sub_ps(z, r);
- const __m512 b =
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
- const __mmask16 d =
- _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
- const __m512 u = _mm512_mul_ps(b, b);
- const __m512 j = _mm512_fmadd_ps(
- _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
- _mm512_set1_ps(0x1.573e2ep-5f)),
- u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
- _mm512_set1_ps(0x1.fffdb6p-2f))),
- u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
- const __m512 res = _mm512_scalef_ps(j, n);
- if (_mm512_kortestz(d, d))
- return res;
- const __m512 zero = _mm512_setzero_ps();
- const __m512 alt = _mm512_mask_blend_ps(
- _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
- return _mm512_mask_blend_ps(d, res, alt);
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m512 ggml_v_silu(__m512 x) {
- const __m512 one = _mm512_set1_ps(1);
- const __m512 zero = _mm512_setzero_ps();
- const __m512 neg_x = _mm512_sub_ps(zero, x);
- const __m512 exp_neg_x = ggml_v_expf(neg_x);
- const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
- return _mm512_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX2__) && defined(__FMA__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m256 ggml_v_expf(__m256 x) {
- const __m256 r = _mm256_set1_ps(0x1.8p23f);
- const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
- const __m256 n = _mm256_sub_ps(z, r);
- const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
- _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
- const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
- const __m256 k = _mm256_castsi256_ps(
- _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
- const __m256i c = _mm256_castps_si256(
- _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
- _mm256_set1_ps(126), _CMP_GT_OQ));
- const __m256 u = _mm256_mul_ps(b, b);
- const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
- _mm256_set1_ps(0x1.573e2ep-5f)), u,
- _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
- _mm256_set1_ps(0x1.fffdb6p-2f))),
- u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
- if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
- return _mm256_fmadd_ps(j, k, k);
- const __m256i g = _mm256_and_si256(
- _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
- _mm256_set1_epi32(0x82000000u));
- const __m256 s1 =
- _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
- const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
- const __m256i d = _mm256_castps_si256(
- _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
- _mm256_set1_ps(192), _CMP_GT_OQ));
- return _mm256_or_ps(
- _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
- _mm256_andnot_ps(
- _mm256_castsi256_ps(d),
- _mm256_or_ps(
- _mm256_and_ps(_mm256_castsi256_ps(c),
- _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
- _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m256 ggml_v_silu(__m256 x) {
- const __m256 one = _mm256_set1_ps(1);
- const __m256 zero = _mm256_setzero_ps();
- const __m256 neg_x = _mm256_sub_ps(zero, x);
- const __m256 exp_neg_x = ggml_v_expf(neg_x);
- const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
- return _mm256_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
-
-#if defined(__FMA__)
-#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
-#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
-#else
-#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
-#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
-#endif
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m128 ggml_v_expf(__m128 x) {
- const __m128 r = _mm_set1_ps(0x1.8p23f);
- const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
- const __m128 n = _mm_sub_ps(z, r);
- const __m128 b =
- NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
- const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
- const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
- const __m128i c =
- _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
- const __m128 u = _mm_mul_ps(b, b);
- const __m128 j =
- MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
- MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
- u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
- if (!_mm_movemask_epi8(c))
- return MADD128(j, k, k);
- const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
- _mm_set1_epi32(0x82000000u));
- const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
- const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
- const __m128i d =
- _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
- return _mm_or_ps(
- _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
- _mm_andnot_ps(_mm_castsi128_ps(d),
- _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
- _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m128 ggml_v_silu(__m128 x) {
- const __m128 one = _mm_set1_ps(1);
- const __m128 zero = _mm_setzero_ps();
- const __m128 neg_x = _mm_sub_ps(zero, x);
- const __m128 exp_neg_x = ggml_v_expf(neg_x);
- const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
- return _mm_div_ps(x, one_plus_exp_neg_x);
-}
-
-#endif // __ARM_NEON / __AVX2__ / __SSE2__
-
-static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
- int i = 0;
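- // process full SIMD-width chunks with the widest path available, then finish the remainder with the scalar loop below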
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
- for (; i + 15 < n; i += 16) {
- _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
- }
-#elif defined(__AVX2__) && defined(__FMA__)
- for (; i + 7 < n; i += 8) {
- _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
- }
-#elif defined(__SSE2__)
- for (; i + 3 < n; i += 4) {
- _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
- }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- for (; i + 3 < n; i += 4) {
- vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
- }
-#endif
- for (; i < n; ++i) {
- y[i] = ggml_silu_f32(x[i]);
- }
-}
-
-static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
- int i = 0;
- ggml_float sum = 0;
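- // store y[i] = exp(x[i] - max) and accumulate the sum; subtracting the row max keeps expf() from overflowing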
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
- for (; i + 15 < n; i += 16) {
- __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
- _mm512_set1_ps(max)));
- _mm512_storeu_ps(y + i, val);
- sum += (ggml_float)_mm512_reduce_add_ps(val);
- }
-#elif defined(__AVX2__) && defined(__FMA__)
- for (; i + 7 < n; i += 8) {
- __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
- _mm256_set1_ps(max)));
- _mm256_storeu_ps(y + i, val);
- __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
- _mm256_castps256_ps128(val));
- val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
- val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
- sum += (ggml_float)_mm_cvtss_f32(val2);
- }
-#elif defined(__SSE2__)
- for (; i + 3 < n; i += 4) {
- __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
- _mm_set1_ps(max)));
- _mm_storeu_ps(y + i, val);
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- val = _mm_add_ps(val, _mm_movehl_ps(val, val));
- val = _mm_add_ss(val, _mm_movehdup_ps(val));
-#else
- __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
- val = _mm_add_ps(val, tmp);
- tmp = _mm_movehl_ps(tmp, val);
- val = _mm_add_ss(val, tmp);
-#endif
- sum += (ggml_float)_mm_cvtss_f32(val);
- }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- for (; i + 3 < n; i += 4) {
- float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
- vdupq_n_f32(max)));
- vst1q_f32(y + i, val);
- sum += (ggml_float)vaddvq_f32(val);
- }
-#endif
- for (; i < n; ++i) {
- float val = expf(x[i] - max);
- sum += (ggml_float)val;
- y[i] = val;
- }
- return sum;
-}
-
-static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
- // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
-
- int i = 0;
- ggml_float sum = 0;
- for (; i < n; ++i) {
- float val = x[i] - max;
- y[i] = val;
- sum += (ggml_float)expf(val);
- }
- return (ggml_float)logf(sum);
-}
-
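-// backward of silu: d/dx [x*s(x)] = s(x)*(1 + x*(1 - s(x))), where s is the logistic sigmoid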
-inline static float ggml_silu_backward_f32(float x, float dy) {
- const float s = 1.0f/(1.0f + expf(-x));
- return dy*s*(1.0f + x*(1.0f - s));
-}
-
-inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
- for (int i = 0; i < n; ++i) {
- dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
- }
-}
-
-inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
- ggml_float sum = 0.0;
- for (int i = 0; i < n; ++i) {
- sum += (ggml_float)x[i];
- }
- *s = sum;
-#else
- vDSP_sve(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
- ggml_float sum = 0.0;
- for (int i = 0; i < n; ++i) {
- sum += (ggml_float)x[i];
- }
- *s = sum;
-}
-
-inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
- float sum = 0.0f;
- for (int i = 0; i < n; ++i) {
- sum += GGML_FP16_TO_FP32(x[i]);
- }
- *s = sum;
-}
-
-inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
- float sum = 0.0f;
- for (int i = 0; i < n; ++i) {
- sum += GGML_BF16_TO_FP32(x[i]);
- }
- *s = sum;
-}
-
-inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
- float max = -INFINITY;
- for (int i = 0; i < n; ++i) {
- max = MAX(max, x[i]);
- }
- *s = max;
-#else
- vDSP_maxv(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
- ggml_vec_norm_f32(n, s, x);
- *s = 1.f/(*s);
-}
-
-inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
- float max = -INFINITY;
- int idx = 0;
- for (int i = 0; i < n; ++i) {
- max = MAX(max, x[i]);
- if (max == x[i]) { idx = i; }
- }
- *s = idx;
-}
-
-// Helpers for polling loops
-#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
-static inline void ggml_thread_cpu_relax(void) {
- __asm__ volatile("yield" ::: "memory");
-}
-#elif defined(__x86_64__)
-static inline void ggml_thread_cpu_relax(void) {
- _mm_pause();
-}
-#else
-static inline void ggml_thread_cpu_relax(void) {;}
-#endif
-
-//
-// NUMA support
-//
-
-#define GGML_NUMA_MAX_NODES 8
-#define GGML_NUMA_MAX_CPUS 512
-
-struct ggml_numa_node {
- uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
- uint32_t n_cpus;
-};
-
-struct ggml_numa_nodes {
- enum ggml_numa_strategy numa_strategy;
- struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
- uint32_t n_nodes;
- uint32_t total_cpus; // hardware threads on system
- uint32_t current_node; // node on which main process is executing
-#if defined(__gnu_linux__)
- cpu_set_t cpuset; // cpuset from numactl
-#else
- uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
-#endif
-};
-
-//
-// ggml state
-//
-
-struct ggml_state {
- struct ggml_numa_nodes numa;
-};
-
-// global state
-static struct ggml_state g_state = {0};
-static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
-
-// TODO: move to threading file
-// critical section via spin lock
-void ggml_critical_section_start(void) {
- while (atomic_flag_test_and_set(&g_state_critical)) {
- // spin
- sched_yield();
- }
-}
-
-void ggml_critical_section_end(void) {
- atomic_flag_clear(&g_state_critical);
-}
-
-static void ggml_barrier(struct ggml_threadpool * tp) {
- int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
- if (n_threads == 1) {
- return;
- }
-
-#ifdef GGML_USE_OPENMP
- #pragma omp barrier
-#else
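- // spin-wait barrier: n_barrier counts arrivals and n_barrier_passed is a generation counter
- // that the last thread to arrive bumps to release the waiters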
- int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
-
- // enter barrier (full seq-cst fence)
- int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
-
- if (n_barrier == (n_threads - 1)) {
- // last thread
- atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
-
- // exit barrier (full seq-cst fence)
- atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
- return;
- }
-
- // wait for other threads
- while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
- ggml_thread_cpu_relax();
- }
-
- // exit barrier (full seq-cst fence)
- // TSAN doesn't support standalone fences yet, so we use a dummy read-modify-write instead
- #ifdef GGML_TSAN_ENABLED
- atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
- #else
- atomic_thread_fence(memory_order_seq_cst);
- #endif
-#endif
-}
-
-#if defined(__gnu_linux__)
-static cpu_set_t ggml_get_numa_affinity(void) {
- cpu_set_t cpuset;
- pthread_t thread;
- thread = pthread_self();
- CPU_ZERO(&cpuset);
- pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
- return cpuset;
-}
-#else
-static uint32_t ggml_get_numa_affinity(void) {
- return 0; // no NUMA support
-}
-#endif
-
-void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
- if (g_state.numa.n_nodes > 0) {
- fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
-
- return;
- }
-
-#if defined(__gnu_linux__)
- struct stat st;
- char path[256];
- int rv;
-
- // set numa scheme
- g_state.numa.numa_strategy = numa_flag;
-
- GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
-
- g_state.numa.cpuset = ggml_get_numa_affinity();
-
- // enumerate nodes
- while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
- rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
- GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
- if (stat(path, &st) != 0) { break; }
- ++g_state.numa.n_nodes;
- }
-
- // enumerate CPUs
- while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
- rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
- GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
- if (stat(path, &st) != 0) { break; }
- ++g_state.numa.total_cpus;
- }
-
- GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
-
- // figure out which node we're on
- uint current_cpu;
- int getcpu_ret = 0;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
- getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
-#else
- // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
-# if !defined(SYS_getcpu) && defined(SYS_get_cpu)
-# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
-# endif
- getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
-#endif
-
- if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
- g_state.numa.n_nodes = 0;
- return;
- }
-
- GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
-
- for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
- struct ggml_numa_node * node = &g_state.numa.nodes[n];
- GGML_PRINT_DEBUG("CPUs on node %u:", n);
- node->n_cpus = 0;
- for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
- rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
- GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
- if (stat(path, &st) == 0) {
- node->cpus[node->n_cpus++] = c;
- GGML_PRINT_DEBUG(" %u", c);
- }
- }
- GGML_PRINT_DEBUG("\n");
- }
-
- if (ggml_is_numa()) {
- FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
- if (fptr != NULL) {
- char buf[42];
- if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
- GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
- }
- fclose(fptr);
- }
- }
-#else
- UNUSED(numa_flag);
- // TODO
-#endif
-}
-
-bool ggml_is_numa(void) {
- return g_state.numa.n_nodes > 1;
-}
-
-#if defined(__ARM_ARCH)
-
-#if defined(__linux__) && defined(__aarch64__)
-#include <sys/auxv.h>
-#elif defined(__APPLE__)
-#include <sys/sysctl.h>
-#endif
-
-#if !defined(HWCAP2_I8MM)
-#define HWCAP2_I8MM 0
-#endif
-
-static void ggml_init_arm_arch_features(void) {
-#if defined(__linux__) && defined(__aarch64__)
- uint32_t hwcap = getauxval(AT_HWCAP);
- uint32_t hwcap2 = getauxval(AT_HWCAP2);
-
- ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
- ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
- ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
-
-#if defined(__ARM_FEATURE_SVE)
- ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
-#endif
-#elif defined(__APPLE__)
- int oldp = 0;
- size_t size = sizeof(oldp);
- if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
- oldp = 0;
- }
- ggml_arm_arch_features.has_neon = oldp;
-
- if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
- oldp = 0;
- }
- ggml_arm_arch_features.has_i8mm = oldp;
-
- ggml_arm_arch_features.has_sve = 0;
- ggml_arm_arch_features.sve_cnt = 0;
-#else
-// Run-time CPU feature detection is not implemented for this platform; fall back to compile-time checks
-#if defined(__ARM_NEON)
- ggml_arm_arch_features.has_neon = 1;
-#else
- ggml_arm_arch_features.has_neon = 0;
-#endif
-
-#if defined(__ARM_FEATURE_MATMUL_INT8)
- ggml_arm_arch_features.has_i8mm = 1;
-#else
- ggml_arm_arch_features.has_i8mm = 0;
-#endif
-
-#if defined(__ARM_FEATURE_SVE)
- ggml_arm_arch_features.has_sve = 1;
- ggml_arm_arch_features.sve_cnt = 16;
-#else
- ggml_arm_arch_features.has_sve = 0;
- ggml_arm_arch_features.sve_cnt = 0;
-#endif
-#endif
-}
-#endif
-
-struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
- GGML_ASSERT(!ggml_get_no_alloc(ctx));
-
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-
- ggml_set_i32(result, value);
-
- return result;
-}
-
-struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
- GGML_ASSERT(!ggml_get_no_alloc(ctx));
-
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-
- ggml_set_f32(result, value);
-
- return result;
-}
-
-struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
- const int n = ggml_nrows(tensor);
- const int nc = tensor->ne[0];
- const size_t n1 = tensor->nb[1];
-
- char * const data = tensor->data;
-
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- assert(tensor->nb[0] == sizeof(int8_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_I16:
- {
- assert(tensor->nb[0] == sizeof(int16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_I32:
- {
- assert(tensor->nb[0] == sizeof(int32_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_F16:
- {
- assert(tensor->nb[0] == sizeof(ggml_fp16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
- }
- } break;
- case GGML_TYPE_BF16:
- {
- assert(tensor->nb[0] == sizeof(ggml_fp16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
- }
- } break;
- case GGML_TYPE_F32:
- {
- assert(tensor->nb[0] == sizeof(float));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
- }
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-
- return tensor;
-}
-
-struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
- const int n = ggml_nrows(tensor);
- const int nc = tensor->ne[0];
- const size_t n1 = tensor->nb[1];
-
- char * const data = tensor->data;
-
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- assert(tensor->nb[0] == sizeof(int8_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_I16:
- {
- assert(tensor->nb[0] == sizeof(int16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_I32:
- {
- assert(tensor->nb[0] == sizeof(int32_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
- }
- } break;
- case GGML_TYPE_F16:
- {
- assert(tensor->nb[0] == sizeof(ggml_fp16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
- }
- } break;
- case GGML_TYPE_BF16:
- {
- assert(tensor->nb[0] == sizeof(ggml_bf16_t));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
- }
- } break;
- case GGML_TYPE_F32:
- {
- assert(tensor->nb[0] == sizeof(float));
- for (int i = 0; i < n; i++) {
- ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
- }
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-
- return tensor;
-}
-
-int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
- if (!ggml_is_contiguous(tensor)) {
- int64_t id[4] = { 0, 0, 0, 0 };
- ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
- return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
- }
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
- return ((int8_t *)(tensor->data))[i];
- }
- case GGML_TYPE_I16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
- return ((int16_t *)(tensor->data))[i];
- }
- case GGML_TYPE_I32:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
- return ((int32_t *)(tensor->data))[i];
- }
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
- return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
- }
- case GGML_TYPE_BF16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
- return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
- }
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(float));
- return ((float *)(tensor->data))[i];
- }
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
- if (!ggml_is_contiguous(tensor)) {
- int64_t id[4] = { 0, 0, 0, 0 };
- ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
- ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
- return;
- }
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
- ((int8_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_I16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
- ((int16_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_I32:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
- ((int32_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
- ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
- } break;
- case GGML_TYPE_BF16:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
- ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
- } break;
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(tensor->nb[0] == sizeof(float));
- ((float *)(tensor->data))[i] = value;
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
- void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
- switch (tensor->type) {
- case GGML_TYPE_I8:
- return ((int8_t *) data)[0];
- case GGML_TYPE_I16:
- return ((int16_t *) data)[0];
- case GGML_TYPE_I32:
- return ((int32_t *) data)[0];
- case GGML_TYPE_F16:
- return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
- case GGML_TYPE_BF16:
- return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
- case GGML_TYPE_F32:
- return ((float *) data)[0];
- default:
- GGML_ABORT("fatal error");
- }
-}
-
-void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
- void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- ((int8_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_I16:
- {
- ((int16_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_I32:
- {
- ((int32_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_F16:
- {
- ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
- } break;
- case GGML_TYPE_BF16:
- {
- ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
- } break;
- case GGML_TYPE_F32:
- {
- ((float *)(data))[0] = value;
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
- if (!ggml_is_contiguous(tensor)) {
- int64_t id[4] = { 0, 0, 0, 0 };
- ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
- return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
- }
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- return ((int8_t *)(tensor->data))[i];
- }
- case GGML_TYPE_I16:
- {
- return ((int16_t *)(tensor->data))[i];
- }
- case GGML_TYPE_I32:
- {
- return ((int32_t *)(tensor->data))[i];
- }
- case GGML_TYPE_F16:
- {
- return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
- }
- case GGML_TYPE_BF16:
- {
- return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
- }
- case GGML_TYPE_F32:
- {
- return ((float *)(tensor->data))[i];
- }
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
- if (!ggml_is_contiguous(tensor)) {
- int64_t id[4] = { 0, 0, 0, 0 };
- ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
- ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
- return;
- }
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- ((int8_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_I16:
- {
- ((int16_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_I32:
- {
- ((int32_t *)(tensor->data))[i] = value;
- } break;
- case GGML_TYPE_F16:
- {
- ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
- } break;
- case GGML_TYPE_BF16:
- {
- ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
- } break;
- case GGML_TYPE_F32:
- {
- ((float *)(tensor->data))[i] = value;
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
- void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
- switch (tensor->type) {
- case GGML_TYPE_I8:
- return ((int8_t *) data)[0];
- case GGML_TYPE_I16:
- return ((int16_t *) data)[0];
- case GGML_TYPE_I32:
- return ((int32_t *) data)[0];
- case GGML_TYPE_F16:
- return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
- case GGML_TYPE_BF16:
- return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
- case GGML_TYPE_F32:
- return ((float *) data)[0];
- default:
- GGML_ABORT("fatal error");
- }
-}
-
-void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
- void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
- switch (tensor->type) {
- case GGML_TYPE_I8:
- {
- ((int8_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_I16:
- {
- ((int16_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_I32:
- {
- ((int32_t *)(data))[0] = value;
- } break;
- case GGML_TYPE_F16:
- {
- ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
- } break;
- case GGML_TYPE_BF16:
- {
- ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
- } break;
- case GGML_TYPE_F32:
- {
- ((float *)(data))[0] = value;
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// ggml_compute_forward_dup
-
-static void ggml_compute_forward_dup_same_cont(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
- GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
- GGML_ASSERT(src0->type == dst->type);
-
- const size_t nb0 = ggml_type_size(src0->type);
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by elements
- const int ne = ggml_nelements(dst);
- const int dr = (ne + nth - 1) / nth;
- const int ie0 = dr * ith;
- const int ie1 = MIN(ie0 + dr, ne);
-
- if (ie0 < ie1) {
- memcpy(
- ((char *) dst->data + ie0*nb0),
- ((char *) src0->data + ie0*nb0),
- (ie1 - ie0) * nb0);
- }
-}
-
-static void ggml_compute_forward_dup_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
- if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_fp16_t)) {
- if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
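- // i10..i13 walk the dst indices; the while loops carry the per-row offset into the higher
- // dimensions, so src0 and dst may have different shapes as long as the element counts match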
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
-}
-
-static void ggml_compute_forward_dup_bf16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
- if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_bf16_t)) {
- if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
-}
-
-static void ggml_compute_forward_dup_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- // TODO: simplify
- if (nb00 == sizeof(float)) {
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- return;
- }
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(float));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
-}
-
-// A simplified version of ggml_compute_forward_dup that doesn't do float up-casting and instead just uses plain old memcpy.
-static void ggml_compute_forward_dup_bytes(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
- GGML_ASSERT(src0->type == dst->type);
-
- GGML_TENSOR_UNARY_OP_LOCALS;
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
- ggml_compute_forward_dup_same_cont(params, dst);
- return;
- }
-
- const size_t type_size = ggml_type_size(src0->type);
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == type_size && nb0 == type_size) {
- // copy by rows
- const size_t rs = ne00 * type_size;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- size_t id = 0;
- char * dst_ptr = (char *) dst->data;
- const size_t rs = ne00 * type_size;
-
- if (nb00 == type_size) {
- // src0 is contiguous in the first dimension, copy by rows
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, type_size);
-
- id += type_size;
- }
- }
- id += rs * (ne01 - ir1);
- }
- }
- }
-
- return;
- }
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, type_size);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_dup(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (src0->type == dst->type) {
- ggml_compute_forward_dup_bytes(params, dst);
- return;
- }
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_dup_f16(params, dst);
- } break;
- case GGML_TYPE_BF16:
- {
- ggml_compute_forward_dup_bf16(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_dup_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_add
-
-static void ggml_compute_forward_add_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(float)) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
- const int64_t nr0 = ne00 / ne10;
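- // src1 can also broadcast along the first dimension: its row is applied nr0 times across the row of src0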
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
- for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
- vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
- ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
- }
- }
- } else {
- // src1 is not contiguous
- for (int ir = ir0; ir < ir1; ++ir) {
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- const int64_t i10 = i0 % ne10;
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
- dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
- }
- }
- }
-}
-
-static void ggml_compute_forward_add_f16_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- if (dst->type == GGML_TYPE_F32) {
- GGML_ASSERT( nb0 == sizeof(float));
- }
- else {
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
- }
-
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(float)) {
- if (dst->type == GGML_TYPE_F16) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
- }
- }
- } else {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
- }
- }
- }
- }
- else {
- // src1 is not contiguous
- GGML_ABORT("fatal error");
- }
-}
-
-static void ggml_compute_forward_add_bf16_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_BF16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- if (dst->type == GGML_TYPE_F32) {
- GGML_ASSERT( nb0 == sizeof(float));
- }
- else {
- GGML_ASSERT(dst->type == GGML_TYPE_BF16);
- GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
- }
-
- GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(float)) {
- if (dst->type == GGML_TYPE_BF16) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
- }
- }
- } else {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
- }
- }
- }
- }
- else {
- // src1 is not contiguous
- GGML_ABORT("fatal error");
- }
-}
-
-static void ggml_compute_forward_add_f16_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(ggml_fp16_t)) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
- }
- }
- }
- else {
- // src1 is not contiguous
- GGML_ABORT("fatal error");
- }
-}
-
-static void ggml_compute_forward_add_bf16_bf16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_BF16);
- GGML_ASSERT(src1->type == GGML_TYPE_BF16);
- GGML_ASSERT(dst->type == GGML_TYPE_BF16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(ggml_bf16_t)) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
- }
- }
- }
- else {
- // src1 is not contiguous
- GGML_ABORT("fatal error");
- }
-}
-
-static void ggml_compute_forward_add_q_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const enum ggml_type type = src0->type;
- const enum ggml_type dtype = dst->type;
- ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dtype)->from_float;
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == ggml_type_size(type));
- GGML_ASSERT(nb10 == sizeof(float));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- GGML_ASSERT(ggml_is_quantized(src0->type));
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
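- // per-thread scratch row for the dequantized values, padded by a cache line to avoid false sharing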
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- // src1 and dst are same shape as src0 => same indices
- const int i13 = i03;
- const int i12 = i02;
- const int i11 = i01;
-
- const int i3 = i03;
- const int i2 = i02;
- const int i1 = i01;
-
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
- void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
-
- assert(ne00 % 32 == 0);
-
- // dequantize row from src0 to temp buffer
- dequantize_row_q(src0_row, wdata, ne00);
- // add src1
- ggml_vec_acc_f32(ne00, wdata, src1_row);
- // quantize row to dst
- if (quantize_row_q != NULL) {
- quantize_row_q(wdata, dst_row, ne00);
- } else {
- memcpy(dst_row, wdata, ne0*nb0);
- }
- }
-}
-
-static void ggml_compute_forward_add(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- if (src1->type == GGML_TYPE_F32) {
- ggml_compute_forward_add_f32(params, dst);
- }
- else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_TYPE_F16:
- {
- if (src1->type == GGML_TYPE_F16) {
- ggml_compute_forward_add_f16_f16(params, dst);
- }
- else if (src1->type == GGML_TYPE_F32) {
- ggml_compute_forward_add_f16_f32(params, dst);
- }
- else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_TYPE_BF16:
- {
- if (src1->type == GGML_TYPE_BF16) {
- ggml_compute_forward_add_bf16_bf16(params, dst);
- }
- else if (src1->type == GGML_TYPE_F32) {
- ggml_compute_forward_add_bf16_f32(params, dst);
- }
- else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- {
- ggml_compute_forward_add_q_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_add1
-
-static void ggml_compute_forward_add1_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-#ifdef GGML_USE_ACCELERATE
- UNUSED(ggml_vec_add1_f32);
-
- vDSP_vadd(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) src1->data), 0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
-#else
- ggml_vec_add1_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- *(float *) src1->data);
-#endif
- }
-}
-
-static void ggml_compute_forward_add1_f16_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- // scalar to add
- const float v = *(float *) src1->data;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
- }
- }
-}
-
-static void ggml_compute_forward_add1_f16_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- // scalar to add
- const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F16);
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
- }
- }
-}
-
-static void ggml_compute_forward_add1_q_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- // scalar to add
- const float v = *(float *) src1->data;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const enum ggml_type type = src0->type;
- ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits(type)->from_float;
-
- // we don't support permuted src0
- GGML_ASSERT(nb00 == ggml_type_size(type));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- GGML_ASSERT(ggml_is_quantized(src0->type));
- GGML_ASSERT(dst->type == src0->type);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
- void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3 ));
-
- assert(ne0 % 32 == 0);
-
- // dequantize row from src0 to temp buffer
- dequantize_row_q(src0_row, wdata, ne0);
- // add the scalar value from src1
- ggml_vec_acc1_f32(ne0, wdata, v);
- // quantize row to dst
- quantize_row_q(wdata, dst_row, ne0);
- }
-}
-
-static void ggml_compute_forward_add1_bf16_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- // scalar to add
- const float v = *(float *) src1->data;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_BF16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_BF16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
- }
- }
-}
-
-static void ggml_compute_forward_add1_bf16_bf16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_scalar(src1));
-
- // scalar to add
- const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(src0->type == GGML_TYPE_BF16);
- GGML_ASSERT(src1->type == GGML_TYPE_BF16);
- GGML_ASSERT(dst->type == GGML_TYPE_BF16);
-
- GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
- }
- }
-}
-
-static void ggml_compute_forward_add1(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_add1_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- if (src1->type == GGML_TYPE_F16) {
- ggml_compute_forward_add1_f16_f16(params, dst);
- }
- else if (src1->type == GGML_TYPE_F32) {
- ggml_compute_forward_add1_f16_f32(params, dst);
- }
- else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_TYPE_BF16:
- {
- if (src1->type == GGML_TYPE_BF16) {
- ggml_compute_forward_add1_bf16_bf16(params, dst);
- }
- else if (src1->type == GGML_TYPE_F32) {
- ggml_compute_forward_add1_bf16_f32(params, dst);
- }
- else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- {
- ggml_compute_forward_add1_q_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_acc
-
-static void ggml_compute_forward_acc_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
- // view src0 and dst with these strides and data offset in bytes during acc
- // nb0 is implicitly element_size because src0 and dst are contiguous
- size_t nb1 = ((int32_t *) dst->op_params)[0];
- size_t nb2 = ((int32_t *) dst->op_params)[1];
- size_t nb3 = ((int32_t *) dst->op_params)[2];
- size_t offset = ((int32_t *) dst->op_params)[3];
- bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
- if (!inplace) {
- if (params->ith == 0) {
- // the memcpy must complete before any thread starts accumulating
- // => thread 0 performs the copy, then all threads sync on the barrier below
- memcpy(
- ((char *) dst->data),
- ((char *) src0->data),
- ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
- }
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src1);
- const int nc = src1->ne[0];
-
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
-
- // src0 and dst as viewed during acc
- const size_t nb0 = ggml_element_size(src0);
-
- const size_t nb00 = nb0;
- const size_t nb01 = nb1;
- const size_t nb02 = nb2;
- const size_t nb03 = nb3;
-
- GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
- GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
-
- GGML_ASSERT(nb10 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are viewed with shape of src1 and offset
- // => same indices
- const int i3 = ir/(ne12*ne11);
- const int i2 = (ir - i3*ne12*ne11)/ne11;
- const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-#ifdef GGML_USE_ACCELERATE
- vDSP_vadd(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
-#else
- ggml_vec_add_f32(nc,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-#endif
- }
-}
-
-static void ggml_compute_forward_acc(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_acc_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sub
-
-static void ggml_compute_forward_sub_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (nb10 == sizeof(float)) {
- for (int ir = ir0; ir < ir1; ++ir) {
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
- const int64_t nr0 = ne00 / ne10;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
- for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
- vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
- ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
- }
- }
- } else {
- // src1 is not contiguous
- for (int ir = ir0; ir < ir1; ++ir) {
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- const int64_t i10 = i0 % ne10;
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
- dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
- }
- }
- }
-}
-
-static void ggml_compute_forward_sub(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sub_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_mul
-
-static void ggml_compute_forward_mul_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- if (nb10 == sizeof(float)) {
- for (int64_t ir = ith; ir < nr; ir += nth) {
- // src0 and dst are same shape => same indices
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
- const int64_t nr0 = ne00 / ne10;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
- for (int64_t r = 0 ; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
- UNUSED(ggml_vec_mul_f32);
-
- vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
- ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
- }
- }
- } else {
- // src1 is not contiguous
- for (int64_t ir = ith; ir < nr; ir += nth) {
- // src0 and dst are same shape => same indices
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
- for (int64_t i0 = 0; i0 < ne00; ++i0) {
- const int64_t i10 = i0 % ne10;
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
- dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
- }
- }
- }
-}
-
-static void ggml_compute_forward_mul(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_mul_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_div
-
-static void ggml_compute_forward_div_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t nr = ggml_nrows(src0);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- if (nb10 == sizeof(float)) {
- for (int64_t ir = ith; ir < nr; ir += nth) {
- // src0 and dst are same shape => same indices
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
- const int64_t nr0 = ne00 / ne10;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
- for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
- UNUSED(ggml_vec_div_f32);
-
- vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
- ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
- }
- }
- } else {
- // src1 is not contiguous
- for (int64_t ir = ith; ir < nr; ir += nth) {
- // src0 and dst are same shape => same indices
- // src1 is broadcastable across src0 and dst in i1, i2, i3
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
- for (int64_t i0 = 0; i0 < ne00; ++i0) {
- const int64_t i10 = i0 % ne10;
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
- dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
- }
- }
- }
-}
-
-static void ggml_compute_forward_div(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_div_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sqr
-
-static void ggml_compute_forward_sqr_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- assert( dst->nb[0] == sizeof(float));
- assert(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_sqr_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_sqr(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sqr_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sqrt
-
-static void ggml_compute_forward_sqrt_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- assert( dst->nb[0] == sizeof(float));
- assert(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_sqrt_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_sqrt(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sqrt_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_log
-
-static void ggml_compute_forward_log_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- GGML_ASSERT( dst->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_log_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_log(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_log_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sin
-
-static void ggml_compute_forward_sin_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- GGML_ASSERT( dst->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_sin_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_sin(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sin_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_cos
-
-static void ggml_compute_forward_cos_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- GGML_ASSERT( dst->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_cos_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_cos(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_cos_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sum
-
-static void ggml_compute_forward_sum_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_scalar(dst));
- assert(src0->nb[0] == sizeof(float));
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
-
- ggml_float sum = 0;
- ggml_float row_sum = 0;
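- // accumulate in ggml_float to limit rounding error when summing many rows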
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- ggml_vec_sum_f32_ggf(ne00,
- &row_sum,
- (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
- sum += row_sum;
- }
- }
- }
- ((float *) dst->data)[0] = sum;
-}
-
-static void ggml_compute_forward_sum_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_scalar(dst));
-
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
-
- float sum = 0;
- float row_sum = 0;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- ggml_vec_sum_f16_ggf(ne00,
- &row_sum,
- (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
- sum += row_sum;
- }
- }
- }
- ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
-}
-
-static void ggml_compute_forward_sum_bf16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_scalar(dst));
-
- assert(src0->nb[0] == sizeof(ggml_bf16_t));
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
- GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
-
- float sum = 0;
- float row_sum = 0;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- ggml_vec_sum_bf16_ggf(ne00,
- &row_sum,
- (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
- sum += row_sum;
- }
- }
- }
- ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
-}
-
-static void ggml_compute_forward_sum(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sum_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_sum_f16(params, dst);
- } break;
- case GGML_TYPE_BF16:
- {
- ggml_compute_forward_sum_bf16(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sum_rows
-
-static void ggml_compute_forward_sum_rows_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
- GGML_ASSERT(dst->nb[0] == sizeof(float));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(ne0 == 1);
- GGML_ASSERT(ne1 == ne01);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne3 == ne03);
-
- for (int64_t i3 = 0; i3 < ne03; i3++) {
- for (int64_t i2 = 0; i2 < ne02; i2++) {
- for (int64_t i1 = 0; i1 < ne01; i1++) {
- float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
- float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
- float row_sum = 0;
- ggml_vec_sum_f32(ne00, &row_sum, src_row);
- dst_row[0] = row_sum;
- }
- }
- }
-}
-
-static void ggml_compute_forward_sum_rows(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sum_rows_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_mean
-
-static void ggml_compute_forward_mean_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(src0->nb[0] == sizeof(float));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- assert(ne0 == 1);
- assert(ne1 == ne01);
- assert(ne2 == ne02);
- assert(ne3 == ne03);
-
- UNUSED(ne0);
- UNUSED(ne1);
- UNUSED(ne2);
- UNUSED(ne3);
-
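- // sum each row, then divide by the row length ne00 to get the mean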
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- ggml_vec_sum_f32(ne00,
- (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
-
- *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
- }
- }
- }
-}
-
-static void ggml_compute_forward_mean(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_mean_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_argmax
-
-static void ggml_compute_forward_argmax_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(src0->nb[0] == sizeof(float));
- assert(dst->nb[0] == sizeof(float));
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
-
- const size_t nb01 = src0->nb[1];
- const size_t nb0 = dst->nb[0];
-
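- // for each row, store the index of its maximum element as an int32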
- for (int64_t i1 = 0; i1 < ne01; i1++) {
- float * src = (float *) ((char *) src0->data + i1*nb01);
- int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
- int v = 0;
- ggml_vec_argmax_f32(ne00, &v, src);
- dst_[0] = v;
- }
-}
-
-static void ggml_compute_forward_argmax(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_argmax_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_count_equal
-
-static void ggml_compute_forward_count_equal_i32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- GGML_ASSERT(src0->type == GGML_TYPE_I32);
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_are_same_shape(src0, src1));
- GGML_ASSERT(ggml_is_scalar(dst));
- GGML_ASSERT(dst->type == GGML_TYPE_I64);
-
- const int64_t nr = ggml_nrows(src0);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- int64_t * sums = (int64_t *) params->wdata;
- int64_t sum_thread = 0;
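- // each thread counts matches over its own row range; thread 0 reduces the per-thread sums after the barrier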
-
- // rows per thread
- const int64_t dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int64_t ir0 = dr*ith;
- const int64_t ir1 = MIN(ir0 + dr, nr);
-
- for (int64_t ir = ir0; ir < ir1; ++ir) {
- const int64_t i03 = ir / (ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
- const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
-
- const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
- const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
-
- for (int64_t i00 = 0; i00 < ne00; ++i00) {
- const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
- const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
-
- sum_thread += val0 == val1;
- }
- }
- if (ith != 0) {
- sums[ith] = sum_thread;
- }
- ggml_barrier(params->threadpool);
-
- if (ith != 0) {
- return;
- }
-
- for (int ith_other = 1; ith_other < nth; ++ith_other) {
- sum_thread += sums[ith_other];
- }
- *((int64_t *) dst->data) = sum_thread;
-}
-
-static void ggml_compute_forward_count_equal(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_I32:
- {
- ggml_compute_forward_count_equal_i32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_repeat
-
-static void ggml_compute_forward_repeat_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_can_repeat(src0, dst));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- // guaranteed to be an integer due to the check in ggml_can_repeat
- const int nr0 = (int)(ne0/ne00);
- const int nr1 = (int)(ne1/ne01);
- const int nr2 = (int)(ne2/ne02);
- const int nr3 = (int)(ne3/ne03);
-
- // TODO: support for transposed / permuted tensors
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- // TODO: maybe this is not optimal?
- for (int i3 = 0; i3 < nr3; i3++) {
- for (int k3 = 0; k3 < ne03; k3++) {
- for (int i2 = 0; i2 < nr2; i2++) {
- for (int k2 = 0; k2 < ne02; k2++) {
- for (int i1 = 0; i1 < nr1; i1++) {
- for (int k1 = 0; k1 < ne01; k1++) {
- for (int i0 = 0; i0 < nr0; i0++) {
- ggml_vec_cpy_f32(ne00,
- (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
- (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
- }
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_repeat_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_can_repeat(src0, dst));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- // guaranteed to be an integer due to the check in ggml_can_repeat
- const int nr0 = (int)(ne0/ne00);
- const int nr1 = (int)(ne1/ne01);
- const int nr2 = (int)(ne2/ne02);
- const int nr3 = (int)(ne3/ne03);
-
- // TODO: support for transposed / permuted tensors
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // TODO: maybe this is not optimal?
- for (int i3 = 0; i3 < nr3; i3++) {
- for (int k3 = 0; k3 < ne03; k3++) {
- for (int i2 = 0; i2 < nr2; i2++) {
- for (int k2 = 0; k2 < ne02; k2++) {
- for (int i1 = 0; i1 < nr1; i1++) {
- for (int k1 = 0; k1 < ne01; k1++) {
- for (int i0 = 0; i0 < nr0; i0++) {
- ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
- ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
- // ggml_vec_cpy_f16(ne00, y, x)
- for (int i = 0; i < ne00; ++i) {
- y[i] = x[i];
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_repeat(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- case GGML_TYPE_I16:
- {
- ggml_compute_forward_repeat_f16(params, dst);
- } break;
- case GGML_TYPE_F32:
- case GGML_TYPE_I32:
- {
- ggml_compute_forward_repeat_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_repeat_back
-
-static void ggml_compute_forward_repeat_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_can_repeat(dst, src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- // guaranteed to be an integer due to the check in ggml_can_repeat
- const int nr0 = (int)(ne00/ne0);
- const int nr1 = (int)(ne01/ne1);
- const int nr2 = (int)(ne02/ne2);
- const int nr3 = (int)(ne03/ne3);
-
- // TODO: support for transposed / permuted tensors
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
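- // clear dst first, since the repeated blocks of src0 are accumulated into it below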
- if (ggml_is_contiguous(dst)) {
- ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
- } else {
- for (int k3 = 0; k3 < ne3; k3++) {
- for (int k2 = 0; k2 < ne2; k2++) {
- for (int k1 = 0; k1 < ne1; k1++) {
- ggml_vec_set_f32(ne0,
- (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
- 0);
- }
- }
- }
- }
-
- // TODO: maybe this is not optimal?
- for (int i3 = 0; i3 < nr3; i3++) {
- for (int k3 = 0; k3 < ne3; k3++) {
- for (int i2 = 0; i2 < nr2; i2++) {
- for (int k2 = 0; k2 < ne2; k2++) {
- for (int i1 = 0; i1 < nr1; i1++) {
- for (int k1 = 0; k1 < ne1; k1++) {
- for (int i0 = 0; i0 < nr0; i0++) {
- ggml_vec_acc_f32(ne0,
- (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
- (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
- }
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_repeat_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_repeat_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_concat
-
-static void ggml_compute_forward_concat_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int32_t dim = ggml_get_op_params_i32(dst, 0);
-
- GGML_ASSERT(dim >= 0 && dim < 4);
-
- int64_t o[4] = {0, 0, 0, 0};
- o[dim] = src0->ne[dim];
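- // o[dim] marks where src1 starts inside dst along the concatenation dimension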
-
- const float * x;
-
- // TODO: smarter multi-threading
- for (int i3 = 0; i3 < ne3; i3++) {
- for (int i2 = ith; i2 < ne2; i2 += nth) {
- for (int i1 = 0; i1 < ne1; i1++) {
- for (int i0 = 0; i0 < ne0; i0++) {
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
- x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
- } else {
- x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
- }
-
- float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
- *y = *x;
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_concat(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_I32:
- {
- ggml_compute_forward_concat_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_abs
-
-static void ggml_compute_forward_abs_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_abs_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_abs(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_abs_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sgn
-
-static void ggml_compute_forward_sgn_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_sgn_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_sgn(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sgn_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_neg
-
-static void ggml_compute_forward_neg_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_neg_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_neg(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_neg_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_step
-
-static void ggml_compute_forward_step_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_step_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_step(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_step_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_tanh
-
-static void ggml_compute_forward_tanh_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_tanh_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_tanh(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_tanh_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_elu
-
-static void ggml_compute_forward_elu_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_elu_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_elu(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_elu_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_relu
-
-static void ggml_compute_forward_relu_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_relu_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_relu(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_relu_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_sigmoid
-
-static void ggml_compute_forward_sigmoid_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_sigmoid_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_sigmoid(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_sigmoid_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_gelu
-
-static void ggml_compute_forward_gelu_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
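- // illustrative example of the row split: nr = 10 rows and nth = 4 threads give dr = 3,
- // so thread 0 gets rows 0..2, thread 2 gets rows 6..8 and thread 3 gets only row 9
- // (ir1 is clamped to nr).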
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_gelu_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
- for (int k = 0; k < nc; k++) {
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- UNUSED(x);
- assert(!isnan(x));
- assert(!isinf(x));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_gelu(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_gelu_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_gelu_quick
-
-static void ggml_compute_forward_gelu_quick_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_gelu_quick_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
- for (int k = 0; k < nc; k++) {
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- UNUSED(x);
- assert(!isnan(x));
- assert(!isinf(x));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_gelu_quick(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_gelu_quick_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_silu
-
-static void ggml_compute_forward_silu_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
- for (int k = 0; k < nc; k++) {
- const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
- UNUSED(x);
- assert(!isnan(x));
- assert(!isinf(x));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_silu(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_silu_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_leaky_relu
-
-static void ggml_compute_forward_leaky_relu_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- float negative_slope;
- memcpy(&negative_slope, dst->op_params, sizeof(float));
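- // the slope is stored bit-for-bit in the int32 op_params array; memcpy reads it
- // back as a float without a type-punning cast.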
-
- assert(dst->nb[0] == sizeof(float));
- assert(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < n; i++) {
- ggml_vec_leaky_relu_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
- }
-}
-
-static void ggml_compute_forward_leaky_relu(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_leaky_relu_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_silu_back
-
-static void ggml_compute_forward_silu_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * grad = dst->src[1];
-
- assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
- assert(ggml_are_same_shape(src0, grad));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_backward_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])),
- (float *) ((char *) grad->data + i1*(grad->nb[1])));
-
-#ifndef NDEBUG
- for (int k = 0; k < nc; k++) {
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- UNUSED(x);
- assert(!isnan(x));
- assert(!isinf(x));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_silu_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_silu_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_hardswish
-
-static void ggml_compute_forward_hardswish_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_hardswish_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_hardswish(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_hardswish_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_hardsigmoid
-
-static void ggml_compute_forward_hardsigmoid_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_hardsigmoid_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_hardsigmoid(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_hardsigmoid_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_exp
-
-static void ggml_compute_forward_exp_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- ggml_vec_exp_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_exp(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_exp_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-
-// ggml_compute_forward_norm
-
-static void ggml_compute_forward_norm_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- GGML_ASSERT(eps > 0.0f);
-
- // TODO: optimize
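- // per row (innermost dimension): y = (x - mean(x)) / sqrt(var(x) + eps),
- // i.e. standard layer normalization without the learned scale/offset.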
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (ggml_float)x[i00];
- }
-
- float mean = sum/ne00;
-
- float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
- ggml_float sum2 = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sum2 += (ggml_float)(v*v);
- }
-
- float variance = sum2/ne00;
- const float scale = 1.0f/sqrtf(variance + eps);
-
- ggml_vec_scale_f32(ne00, y, scale);
- }
- }
- }
-}
-
-static void ggml_compute_forward_norm(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_norm_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_rms_norm
-
-static void ggml_compute_forward_rms_norm_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- GGML_ASSERT(eps > 0.0f);
-
- // TODO: optimize
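- // per row: y = x / sqrt(mean(x^2) + eps). For example, a row [3, 4] has
- // mean(x^2) = 12.5, so (ignoring eps) it is scaled by 1/sqrt(12.5) ~= 0.283,
- // giving roughly [0.849, 1.131].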
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (ggml_float)(x[i00] * x[i00]);
- }
-
- const float mean = sum/ne00;
-
- float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
- memcpy(y, x, ne00 * sizeof(float));
- // for (int i00 = 0; i00 < ne00; i00++) {
- // y[i00] = x[i00];
- // }
-
- const float scale = 1.0f/sqrtf(mean + eps);
-
- ggml_vec_scale_f32(ne00, y, scale);
- }
- }
- }
-}
-
-static void ggml_compute_forward_rms_norm(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_rms_norm_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-static void ggml_compute_forward_rms_norm_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- // TODO: optimize
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
- // src1 is same shape as src0 => same indices
- const int64_t i11 = i01;
- const int64_t i12 = i02;
- const int64_t i13 = i03;
-
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
-
- ggml_float sum_xx = 0.0;
- ggml_float sum_xdz = 0.0;
-
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum_xx += (ggml_float)(x[i00] * x[i00]);
- sum_xdz += (ggml_float)(x[i00] * dz[i00]);
- }
-
- //const float mean = (float)(sum_xx)/ne00;
- const float mean_eps = (float)(sum_xx)/ne00 + eps;
- const float sum_eps = (float)(sum_xx) + eps*ne00;
- //const float mean_xdz = (float)(sum_xdz)/ne00;
- // we could cache rms from forward pass to improve performance.
- // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
- //const float rms = sqrtf(mean_eps);
- const float rrms = 1.0f / sqrtf(mean_eps);
- //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
-
- {
- // z = rms_norm(x)
- //
- // rms_norm(src0) =
- // scale(
- // src0,
- // div(
- // 1,
- // sqrt(
- // add(
- // scale(
- // sum(
- // sqr(
- // src0)),
- // (1.0/N)),
- // eps))));
-
- // postorder:
- // ## op args grad
- // 00 param src0 grad[#00]
- // 01 const 1
- // 02 sqr (#00) grad[#02]
- // 03 sum (#02) grad[#03]
- // 04 const 1/N
- // 05 scale (#03, #04) grad[#05]
- // 06 const eps
- // 07 add (#05, #06) grad[#07]
- // 08 sqrt (#07) grad[#08]
- // 09 div (#01,#08) grad[#09]
- // 10 scale (#00,#09) grad[#10]
- //
- // backward pass, given grad[#10]
- // #10: scale
- // grad[#00] += scale(grad[#10],#09)
- // grad[#09] += sum(mul(grad[#10],#00))
- // #09: div
- // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
- // #08: sqrt
- // grad[#07] += mul(grad[#08], div(0.5, #08))
- // #07: add
- // grad[#05] += grad[#07]
- // #05: scale
- // grad[#03] += scale(grad[#05],#04)
- // #03: sum
- // grad[#02] += repeat(grad[#03], #02)
- // #02:
- // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
- //
- // substitute and simplify:
- // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
- // grad[#02] = repeat(grad[#03], #02)
- // grad[#02] = repeat(scale(grad[#05],#04), #02)
- // grad[#02] = repeat(scale(grad[#07],#04), #02)
- // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
- // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
- // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
- // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
- // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
- // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
- // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
- // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
- // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
- // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
- // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
- // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
- // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
- // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
- // a = b*c + d*e
- // a = b*c*f/f + d*e*f/f
- // a = (b*c*f + d*e*f)*(1/f)
- // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
- // a = (b + d*e/c)*c
- // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
- // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
- // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
- // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
- // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
- // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
- // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
- // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
- // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
- // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
- }
- // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
- // post-order:
- // dx := x
- // dx := scale(dx,-mean_xdz/mean_eps)
- // dx := add(dx, dz)
- // dx := scale(dx, rrms)
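- // note: (-sum_xdz)/sum_eps below equals -mean_xdz/mean_eps, since
- // sum_eps = ne00*mean_eps; using the raw sums avoids an extra division.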
- float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
- ggml_vec_cpy_f32 (ne00, dx, x);
- // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
- ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
- ggml_vec_acc_f32 (ne00, dx, dz);
- ggml_vec_scale_f32(ne00, dx, rrms);
- }
- }
- }
-}
-
-static void ggml_compute_forward_rms_norm_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_rms_norm_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_group_norm
-
-static void ggml_compute_forward_group_norm_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- // TODO: optimize
-
- float eps;
- memcpy(&eps, dst->op_params + 1, sizeof(float));
-
- int n_channels = src0->ne[2];
- int n_groups = dst->op_params[0];
- int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
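- // e.g. 32 channels in 8 groups give 4 channels per group; each group is normalized
- // independently over its ne00*ne01*step elements within every i3 slice.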
- for (int i = ith; i < n_groups; i += nth) {
- int start = i * n_channels_per_group;
- int end = start + n_channels_per_group;
- if (end > n_channels) {
- end = n_channels;
- }
- int step = end - start;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- ggml_float sum = 0.0;
- for (int64_t i02 = start; i02 < end; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
- ggml_float sumr = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sumr += (ggml_float)x[i00];
- }
- sum += sumr;
- }
- }
- const float mean = sum / (ne00 * ne01 * step);
-
- ggml_float sum2 = 0.0;
- for (int64_t i02 = start; i02 < end; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
- float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
- ggml_float sumr = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sumr += (ggml_float)(v * v);
- }
- sum2 += sumr;
- }
- }
- const float variance = sum2 / (ne00 * ne01 * step);
- const float scale = 1.0f / sqrtf(variance + eps);
-
- for (int64_t i02 = start; i02 < end; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
- ggml_vec_scale_f32(ne00, y, scale);
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_group_norm(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_group_norm_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_mul_mat
-
-static void ggml_compute_forward_mul_mat_one_chunk(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const int64_t num_rows_per_vec_dot,
- const int64_t ir0_start,
- const int64_t ir0_end,
- const int64_t ir1_start,
- const int64_t ir1_end) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const enum ggml_type type = src0->type;
-
- const bool src1_cont = ggml_is_contiguous(src1);
-
- ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
- enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
-
- // broadcast factors
- const int64_t r2 = ne12 / ne02;
- const int64_t r3 = ne13 / ne03;
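- // e.g. ne02 = 1 and ne12 = 8 give r2 = 8: every src1 slice i12 maps back to
- // src0 slice i02 = i12/8 = 0, i.e. the single src0 slice is broadcast across all 8.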
-
- //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
-
- // threads with no work simply return early
- if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
- return;
- }
-
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
- assert(ne12 % ne02 == 0);
- assert(ne13 % ne03 == 0);
-
- // block-tiling attempt
- const int64_t blck_0 = 16;
- const int64_t blck_1 = 16;
-
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
-
- // attempt to reduce false-sharing (does not seem to make a difference)
- // 16 * 2, accounting for mmla kernels
- float tmp[32];
-
- for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
- for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
- const int64_t i13 = (ir1 / (ne12 * ne1));
- const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
- const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
- // broadcast src0 into src1
- const int64_t i03 = i13 / r3;
- const int64_t i02 = i12 / r2;
-
- const int64_t i1 = i11;
- const int64_t i2 = i12;
- const int64_t i3 = i13;
-
- const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
-
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
- // the original src1 data pointer, so we should index using the indices directly
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
- const char * src1_col = (const char*)wdata +
- (src1_cont || src1->type != vec_dot_type
- ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
- : (i11 * nb11 + i12 * nb12 + i13 * nb13));
- float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
- //}
-
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
- vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
- }
-
- for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
- memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_mul_mat(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const enum ggml_type type = src0->type;
-
- enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
- ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;
- ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat;
- int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
- int64_t const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
- ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
- ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
-
- GGML_ASSERT(ne0 == ne01);
- GGML_ASSERT(ne1 == ne11);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == ggml_type_size(type));
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- // nb01 >= nb00 - src0 is not transposed
- // compute by src0 rows
-
-#if GGML_USE_LLAMAFILE
- // broadcast factors
- const int64_t r2 = ne12 / ne02;
- const int64_t r3 = ne13 / ne03;
-
- const bool src1_cont = ggml_is_contiguous(src1);
-
- if (src1_cont) {
- for (int64_t i13 = 0; i13 < ne13; i13++)
- for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
- (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
- nb01/ggml_type_size(src0->type),
- (const char *)src1->data + i12*nb12 + i13*nb13,
- nb11/ggml_type_size(src1->type),
- (char *)dst->data + i12*nb2 + i13*nb3,
- nb1/ggml_type_size(dst->type),
- ith, nth,
- src0->type,
- src1->type,
- dst->type))
- goto UseGgmlGemm1;
- return;
- }
-UseGgmlGemm1:;
-#endif
-
- if (src1->type != vec_dot_type) {
- char * wdata = params->wdata;
-
- const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
- const size_t nbw2 = nbw1*ne11;
- const size_t nbw3 = nbw2*ne12;
-
- assert(params->wsize >= ne13*nbw3);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- for (int64_t i13 = 0; i13 < ne13; ++i13) {
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- int64_t i11_processed = 0;
- if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
- from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
- 4, ne10, blck_size_interleave);
- }
- i11_processed = ne11 - ne11 % 4;
- }
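- // the interleaved-block path above quantizes src1 four rows at a time
- // (from_float_to_mat); any leftover rows (ne11 % 4) fall through to the
- // scalar from_float loop below.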
- for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
- from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
- ne10);
- }
- }
- }
- }
-
- if (ith == 0) {
- // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
- atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed);
- }
-
- ggml_barrier(params->threadpool);
-
-#if GGML_USE_LLAMAFILE
- if (src1->type != vec_dot_type) {
- const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
- for (int64_t i13 = 0; i13 < ne13; i13++)
- for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
- (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
- nb01/ggml_type_size(src0->type),
- (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
- row_size/ggml_type_size(vec_dot_type),
- (char *)dst->data + i12*nb2 + i13*nb3,
- nb1/ggml_type_size(dst->type),
- ith, nth,
- src0->type,
- vec_dot_type,
- dst->type))
- goto UseGgmlGemm2;
- return;
- }
-UseGgmlGemm2:;
-#endif
-
- // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
- const int64_t nr0 = ne0;
-
- // This is the size of the rest of the dimensions of the result
- const int64_t nr1 = ne1 * ne2 * ne3;
-
- // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
- int64_t num_rows_per_vec_dot = vec_dot_num_rows;
- // TODO: currently the mmla kernels support only even numbered rows/cols.
- // this check can be removed once they are extended to support odd numbered rows/cols too
- if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
- num_rows_per_vec_dot = 1;
- }
-
- // Now select a reasonable chunk size.
- int chunk_size = 16;
-
- // We need to step up the size if it's small
- if (nr0 == 1 || nr1 == 1) {
- chunk_size = 64;
- }
-
- // distribute the work across the inner or outer loop based on which one is larger
- // The number of chunks in the 0/1 dim.
- // CEIL(nr0/chunk_size)
- int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
- int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
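- // e.g. nr0 = 4096, nr1 = 512 with chunk_size = 16 give nchunk0 = 256 and nchunk1 = 32,
- // i.e. 8192 chunks in total, typically far more than nth*4.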
-
- // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
- // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
- // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
- if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
- // distribute the thread work across the inner or outer loop based on which one is larger
- nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
- nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
- }
-
- // The number of elements in each chunk
- const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
- const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
- if ((ggml_n_dims(src0) == 2) && gemv) {
- const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
- const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
- int64_t src0_start = (ith * ne01) / nth;
- int64_t src0_end = ((ith + 1) * ne01) / nth;
- src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
- src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
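- // round each thread's [src0_start, src0_end) range up to a multiple of
- // matmul_num_cols so the gemv/gemm kernels always work on whole blocks;
- // e.g. with matmul_num_cols = 8, a start of 10 is rounded up to 16.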
- if (src0_start >= src0_end) return;
-
- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
- if (gemm && (ne11 > 3)) {
- gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
- }
- for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
- gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
- src0_end - src0_start);
- }
- return;
- }
-
- // The first chunk comes from our thread_id, the rest will get auto-assigned.
- int current_chunk = ith;
-
- while (current_chunk < nchunk0 * nchunk1) {
- const int64_t ith0 = current_chunk % nchunk0;
- const int64_t ith1 = current_chunk / nchunk0;
-
- const int64_t ir0_start = dr0 * ith0;
- const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
-
- const int64_t ir1_start = dr1 * ith1;
- const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
-
- ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
-
- if (nth >= nchunk0 * nchunk1) {
- break;
- }
-
- current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed);
- }
-}
-
-// ggml_compute_forward_mul_mat_id
-
-static void ggml_compute_forward_mul_mat_id(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * ids = dst->src[2];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const enum ggml_type type = src0->type;
-
- const bool src1_cont = ggml_is_contiguous(src1);
-
- ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
- enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
- ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
- ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == ggml_type_size(type));
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- // row groups
- const int n_ids = ids->ne[0]; // n_expert_used
- const int n_as = ne02; // n_expert
-
- char * wdata_src1_end = (src1->type == vec_dot_type) ?
- (char *) params->wdata :
- (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
-
- struct mmid_row_mapping {
- int32_t i1;
- int32_t i2;
- };
-
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
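- // wdata layout here: [quantized src1 (if converted)][matrix_row_counts: n_as int64s]
- // [matrix_rows: one (i1, i2) mapping per routed row]. Rows of src1 are bucketed by
- // the expert they were routed to, so each expert matrix below only touches its own rows.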
-
- if (src1->type != vec_dot_type) {
- char * wdata = params->wdata;
-
- const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
- const size_t nbw2 = nbw1*ne11;
- const size_t nbw3 = nbw2*ne12;
-
- assert(params->wsize >= ne13*nbw3);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- for (int64_t i13 = 0; i13 < ne13; ++i13) {
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
- from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
- ne10);
- }
- }
- }
- }
-
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
- if (ith == 0) {
- // initialize matrix_row_counts
- memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
-
- // group rows by src0 matrix
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
- for (int id = 0; id < n_ids; ++id) {
- const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
-
- assert(i02 >= 0 && i02 < n_as);
-
- MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
- matrix_row_counts[i02] += 1;
- }
- }
- }
-
- ggml_barrier(params->threadpool);
-
- // compute each matrix multiplication in sequence
- for (int cur_a = 0; cur_a < n_as; ++cur_a) {
- const int64_t cne1 = matrix_row_counts[cur_a];
-
- if (cne1 == 0) {
- continue;
- }
-
- const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
- const int64_t nr0 = ne01; // src0 rows
- const int64_t nr1 = cne1; // src1 rows
-
- if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
- int64_t src0_cur_start = (ith * ne01) / nth;
- int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
- src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
- src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
- if (src0_cur_start >= src0_cur_end) return;
-
- for (int ir1 = 0; ir1 < nr1; ir1++) {
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
- const int id = row_mapping.i1; // selected expert index
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = row_mapping.i2; // row index in src1
-
- const int64_t i1 = id; // selected expert index
- const int64_t i2 = i12; // row
-
- const char * src1_col = (const char *) wdata +
- (src1_cont || src1->type != vec_dot_type
- ? (i11 + i12 * ne11) * row_size
- : (i11 * nb11 + i12 * nb12));
-
- gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
- (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
- }
- continue;
- }
-
- // distribute the thread work across the inner or outer loop based on which one is larger
-
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
- const int64_t ith0 = ith % nth0;
- const int64_t ith1 = ith / nth0;
-
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
-
- const int64_t ir010 = dr0*ith0;
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
-
- const int64_t ir110 = dr1*ith1;
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
- // threads with no work simply yield (not sure if it helps)
- //if (ir010 >= ir011 || ir110 >= ir111) {
- // sched_yield();
- // continue;
- //}
-
- // block-tiling attempt
- const int64_t blck_0 = 16;
- const int64_t blck_1 = 16;
-
- // attempt to reduce false-sharing (does not seem to make a difference)
- float tmp[16];
-
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
- const int64_t _i12 = ir1; // logical row index for this expert
-
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
- const int id = row_mapping.i1; // selected expert index
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = row_mapping.i2; // row index in src1
-
- const int64_t i1 = id; // selected expert index
- const int64_t i2 = i12; // row
-
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
- // the original src1 data pointer, so we should index using the indices directly
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
- const char * src1_col = (const char *) wdata +
- (src1_cont || src1->type != vec_dot_type
- ? (i11 + i12*ne11)*row_size
- : (i11*nb11 + i12*nb12));
-
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
-
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
- //}
-
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
- vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
- }
-
- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
- }
- }
- }
- }
-
-#undef MMID_MATRIX_ROW
-}
-
-// ggml_compute_forward_out_prod
-
-static void ggml_compute_forward_out_prod_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_ASSERT(ne0 == ne00);
- GGML_ASSERT(ne1 == ne10);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne3 == ne13);
- GGML_ASSERT(ne03 == ne13);
-
- // we don't support permuted src0 or src1
- GGML_ASSERT(nb00 == sizeof(float));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- // GGML_ASSERT(nb0 <= nb1);
- // GGML_ASSERT(nb1 <= nb2);
- // GGML_ASSERT(nb2 <= nb3);
-
- // nb01 >= nb00 - src0 is not transposed
- // compute by src0 rows
-
- if (ith == 0) {
- ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
- }
- ggml_barrier(params->threadpool);
-
- // dst[:,:,:,:] = 0
- // for i2,i3:
- // for i1:
- // for i01:
- // for i0:
- // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
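- // equivalently, for each (i2, i3) slice this computes dst = src0 * src1^T as a sum of
- // rank-1 (outer product) updates: the inner loops add the column src0[:, i01], scaled
- // by the single element src1[i1, i01], into dst[:, i1].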
-
- // parallelize by last three dimensions
-
- // total rows in dst
- const int64_t nr = ne1*ne2*ne3;
-
- // rows per thread
- const int64_t dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int64_t ir0 = dr*ith;
- const int64_t ir1 = MIN(ir0 + dr, nr);
-
- // block-tiling attempt
- const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
- const int64_t blck_1 = 16;
-
- for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
- const int64_t bir1 = MIN(bir + blck_1, ir1);
- for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
- const int64_t bne01 = MIN(bi01 + blck_0, ne01);
- for (int64_t ir = bir; ir < bir1; ++ir) {
- // dst indices
- const int64_t i3 = ir/(ne2*ne1);
- const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
- const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- const int64_t i02 = i2;
- const int64_t i03 = i3;
-
- //const int64_t i10 = i1;
- const int64_t i12 = i2;
- const int64_t i13 = i3;
-
-#if GGML_VEC_MAD_UNROLL > 2
- const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
- for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
- const int64_t i11 = i01;
-
- float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
- float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
-
- ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
- }
- for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
- const int64_t i11 = i01;
-
- float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
- float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
-
- ggml_vec_mad_f32(ne0, d, s0, *s1);
- }
-#else
- for (int64_t i01 = bi01; i01 < bne01; ++i01) {
- const int64_t i11 = i01;
-
- float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
- float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
-
- ggml_vec_mad_f32(ne0, d, s0, *s1);
- }
-#endif
- }
- }
- }
-}
-
-static void ggml_compute_forward_out_prod_q_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const enum ggml_type type = src0->type;
- ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne03 == ne13);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
- // we don't support permuted src0 dim0
- GGML_ASSERT(nb00 == ggml_type_size(type));
-
- // dst dim0 cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- // GGML_ASSERT(nb0 <= nb1);
- // GGML_ASSERT(nb1 <= nb2);
- // GGML_ASSERT(nb2 <= nb3);
-
- GGML_ASSERT(ne0 == ne00);
- GGML_ASSERT(ne1 == ne10);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne3 == ne03);
-
- // nb01 >= nb00 - src0 is not transposed
- // compute by src0 rows
-
- if (ith == 0) {
- ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
- }
- ggml_barrier(params->threadpool);
-
- // parallelize by last three dimensions
-
- // total rows in dst
- const int64_t nr = ne1*ne2*ne3;
-
- // rows per thread
- const int64_t dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int64_t ir0 = dr*ith;
- const int64_t ir1 = MIN(ir0 + dr, nr);
-
- // dst[:,:,:,:] = 0
- // for i2,i3:
- // for i1:
- // for i01:
- // for i0:
- // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
- float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
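- // each thread owns a private ne0-float scratch row (padded by a cache line to avoid
- // false sharing); one quantized src0 row at a time is dequantized into it and then
- // accumulated into dst with vec_mad.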
-
- for (int64_t ir = ir0; ir < ir1; ++ir) {
- // dst indices
- const int64_t i3 = ir/(ne2*ne1);
- const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
- const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
- const int64_t i02 = i2;
- const int64_t i03 = i3;
-
- //const int64_t i10 = i1;
- const int64_t i12 = i2;
- const int64_t i13 = i3;
-
- for (int64_t i01 = 0; i01 < ne01; ++i01) {
- const int64_t i11 = i01;
-
- float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
- float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
- float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
-
- dequantize_row_q(s0, wdata, ne0);
- ggml_vec_mad_f32(ne0, d, wdata, *s1);
- }
- }
-}
-
-static void ggml_compute_forward_out_prod(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- {
- ggml_compute_forward_out_prod_q_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- GGML_ABORT("fatal error"); // todo
- // ggml_compute_forward_out_prod_f16_f32(params, dst);
- }
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_out_prod_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_scale
-
-static void ggml_compute_forward_scale_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- // scale factor
- float v;
- memcpy(&v, dst->op_params, sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- const size_t nb01 = src0->nb[1];
-
- const size_t nb1 = dst->nb[1];
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- if (dst->data != src0->data) {
- // src0 is same shape as dst => same indices
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
- }
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
- }
-}
-
-static void ggml_compute_forward_scale(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_scale_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_set
-
-static void ggml_compute_forward_set_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
- // view src0 and dst with these strides and data offset inbytes during set
- // nb0 is implicitly element_size because src0 and dst are contiguous
- size_t nb1 = ((int32_t *) dst->op_params)[0];
- size_t nb2 = ((int32_t *) dst->op_params)[1];
- size_t nb3 = ((int32_t *) dst->op_params)[2];
- size_t offset = ((int32_t *) dst->op_params)[3];
- bool inplace = (bool) ((int32_t *) dst->op_params)[4];
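- // op_params layout for SET: [0..2] = view strides nb1..nb3, [3] = byte offset,
- // [4] = inplace flag. When not in place, dst is first seeded with a full copy of
- // src0 and only the src1-shaped window starting at `offset` is overwritten below.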
-
- if (!inplace) {
- if (params->ith == 0) {
- // memcpy needs to be synchronized across threads to avoid race conditions.
- // => do it in INIT phase
- memcpy(
- ((char *) dst->data),
- ((char *) src0->data),
- ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
- }
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src1);
- const int nc = src1->ne[0];
-
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
- GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
-
- // src0 and dst as viewed during set
- const size_t nb0 = ggml_element_size(src0);
-
- const int im0 = (ne10 == 0 ? 0 : ne10-1);
- const int im1 = (ne11 == 0 ? 0 : ne11-1);
- const int im2 = (ne12 == 0 ? 0 : ne12-1);
- const int im3 = (ne13 == 0 ? 0 : ne13-1);
-
- GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
-
- GGML_ASSERT(nb10 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 and dst are viewed with shape of src1 and offset
- // => same indices
- const int i3 = ir/(ne12*ne11);
- const int i2 = (ir - i3*ne12*ne11)/ne11;
- const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
- }
-}
-
-static void ggml_compute_forward_set(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_set_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_cpy
-
-static void ggml_compute_forward_cpy(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_cont
-
-static void ggml_compute_forward_cont(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_reshape
-
-static void ggml_compute_forward_reshape(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- // NOP
- UNUSED(params);
- UNUSED(dst);
-}
-
-// ggml_compute_forward_view
-
-static void ggml_compute_forward_view(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * dst) {
- // NOP
- UNUSED(params);
- UNUSED(dst);
-}
-
-// ggml_compute_forward_permute
-
-static void ggml_compute_forward_permute(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * dst) {
- // NOP
- UNUSED(params);
- UNUSED(dst);
-}
-
-// ggml_compute_forward_transpose
-
-static void ggml_compute_forward_transpose(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * dst) {
- // NOP
- UNUSED(params);
- UNUSED(dst);
-}
-
-// ggml_compute_forward_get_rows
-
-static void ggml_compute_forward_get_rows_q(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1);
-
- const enum ggml_type type = src0->type;
- ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
- assert(ne0 == nc);
- assert(ne02 == ne11);
- assert(nb00 == ggml_type_size(type));
- assert(ggml_nrows(dst) == nr);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
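- // each flat index i into src1 decomposes into (i10, i11, i12); the value stored there is the
- // row id i01 of src0, whose quantized row is dequantized into the corresponding dst row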
- for (int64_t i = ir0; i < ir1; ++i) {
- const int64_t i12 = i/(ne11*ne10);
- const int64_t i11 = (i - i12*ne11*ne10)/ne10;
- const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
- GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
- dequantize_row_q(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
- }
-}
-
-static void ggml_compute_forward_get_rows_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1);
-
- assert(ne0 == nc);
- assert(ne02 == ne11);
- assert(nb00 == sizeof(ggml_fp16_t));
- assert(ggml_nrows(dst) == nr);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int64_t i = ir0; i < ir1; ++i) {
- const int64_t i12 = i/(ne11*ne10);
- const int64_t i11 = (i - i12*ne11*ne10)/ne10;
- const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
- GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
- ggml_fp16_to_fp32_row(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
- }
-}
-
-static void ggml_compute_forward_get_rows_bf16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1);
-
- assert(ne0 == nc);
- assert(ne02 == ne11);
- assert(nb00 == sizeof(ggml_bf16_t));
- assert(ggml_nrows(dst) == nr);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int64_t i = ir0; i < ir1; ++i) {
- const int64_t i12 = i/(ne11*ne10);
- const int64_t i11 = (i - i12*ne11*ne10)/ne10;
- const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
- GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
- ggml_bf16_to_fp32_row(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
- }
-}
-
-static void ggml_compute_forward_get_rows_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1);
-
- assert(ne0 == nc);
- assert(ne02 == ne11);
- assert(nb00 == sizeof(float));
- assert(ggml_nrows(dst) == nr);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int64_t i = ir0; i < ir1; ++i) {
- const int64_t i12 = i/(ne11*ne10);
- const int64_t i11 = (i - i12*ne11*ne10)/ne10;
- const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
- GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
- (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
- }
-}
-
-static void ggml_compute_forward_get_rows(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- {
- ggml_compute_forward_get_rows_q(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_get_rows_f16(params, dst);
- } break;
- case GGML_TYPE_BF16:
- {
- ggml_compute_forward_get_rows_bf16(params, dst);
- } break;
- case GGML_TYPE_F32:
- case GGML_TYPE_I32:
- {
- ggml_compute_forward_get_rows_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-
- //static bool first = true;
- //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
- //if (first) {
- // first = false;
- //} else {
- // for (int k = 0; k < dst->ne[1]; ++k) {
- // for (int j = 0; j < dst->ne[0]/16; ++j) {
- // for (int i = 0; i < 16; ++i) {
- // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
- // }
- // printf("\n");
- // }
- // printf("\n");
- // }
- // printf("\n");
- // exit(0);
- //}
-}
-
-// ggml_compute_forward_get_rows_back
-
-static void ggml_compute_forward_get_rows_back_f32_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_is_contiguous(dst));
-
- // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
- memset(dst->data, 0, ggml_nbytes(dst));
-
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
-
- GGML_ASSERT( dst->ne[0] == nc);
- GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
-
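- // backward of get_rows: accumulate each incoming gradient row of src0 into the dst row
- // selected by the corresponding index in src1 (converting f16 -> f32)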
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
-
- for (int j = 0; j < nc; ++j) {
- ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
- ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
- }
- }
-}
-
-static void ggml_compute_forward_get_rows_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- GGML_ASSERT(ggml_is_contiguous(dst));
-
- // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
- memset(dst->data, 0, ggml_nbytes(dst));
-
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
-
- GGML_ASSERT( dst->ne[0] == nc);
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
-
- ggml_vec_add_f32(nc,
- (float *) ((char *) dst->data + r*dst->nb[1]),
- (float *) ((char *) dst->data + r*dst->nb[1]),
- (float *) ((char *) src0->data + i*src0->nb[1]));
- }
-}
-
-static void ggml_compute_forward_get_rows_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_get_rows_back_f32_f16(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_get_rows_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-
- //static bool first = true;
- //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
- //if (first) {
- // first = false;
- //} else {
- // for (int k = 0; k < dst->ne[1]; ++k) {
- // for (int j = 0; j < dst->ne[0]/16; ++j) {
- // for (int i = 0; i < 16; ++i) {
- // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
- // }
- // printf("\n");
- // }
- // printf("\n");
- // }
- // printf("\n");
- // exit(0);
- //}
-}
-
-// ggml_compute_forward_diag
-
-static void ggml_compute_forward_diag_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- // TODO: handle transposed/permuted matrices
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(ne00 == ne0);
- GGML_ASSERT(ne00 == ne1);
- GGML_ASSERT(ne01 == 1);
- GGML_ASSERT(ne02 == ne2);
- GGML_ASSERT(ne03 == ne3);
-
- GGML_ASSERT(nb00 == sizeof(float));
- GGML_ASSERT(nb0 == sizeof(float));
-
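- // write the single src0 row as the main diagonal of each dst matrix, zero everywhere else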
- for (int i3 = 0; i3 < ne3; i3++) {
- for (int i2 = 0; i2 < ne2; i2++) {
- for (int i1 = 0; i1 < ne1; i1++) {
- float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
- float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
- for (int i0 = 0; i0 < i1; i0++) {
- d[i0] = 0;
- }
- d[i1] = s[i1];
- for (int i0 = i1+1; i0 < ne0; i0++) {
- d[i0] = 0;
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_diag(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_diag_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_diag_mask_inf
-
-static void ggml_compute_forward_diag_mask_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const float value) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int n_past = ((int32_t *) dst->op_params)[0];
- const bool inplace = src0->data == dst->data;
-
- GGML_ASSERT(n_past >= 0);
-
- if (!inplace) {
- if (ith == 0) {
- // memcpy needs to be synchronized across threads to avoid race conditions.
- // => do it in INIT phase
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
- GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
- memcpy(
- ((char *) dst->data),
- ((char *) src0->data),
- ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
- }
-
- // TODO: handle transposed/permuted matrices
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
- const int nr = src0->ne[1];
- const int nz = n/nr;
-
- GGML_ASSERT( dst->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
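- // element (row j, col i) is masked when i > n_past + j, i.e. everything above the
- // diagonal shifted right by n_past is set to `value`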
- for (int k = 0; k < nz; k++) {
- for (int j = ith; j < nr; j += nth) {
- for (int i = n_past; i < nc; i++) {
- if (i > n_past + j) {
- *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_diag_mask_inf(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-static void ggml_compute_forward_diag_mask_zero(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_diag_mask_f32(params, dst, 0);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_soft_max
-
-static void ggml_compute_forward_soft_max_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- assert(ggml_is_contiguous(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- float scale = 1.0f;
- float max_bias = 0.0f;
-
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
-
- // TODO: handle transposed/permuted matrices
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
-
- // TODO: is this supposed to be ceil instead of floor?
- // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
- const uint32_t n_head = ne02;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
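-
- // m0/m1 are the ALiBi slope bases: heads below n_head_log2 use powers of m0,
- // the remaining heads use odd powers of m1 (see the per-row `slope` below)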
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
- }
- }
- }
-
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
-#endif
-
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
-
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
-
- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
-
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_soft_max(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_soft_max_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-
-// ggml_compute_forward_soft_max_back
-
-static void ggml_compute_forward_soft_max_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
- GGML_ASSERT(ggml_are_same_shape(src1, dst));
-
- // TODO: handle transposed/permuted matrices
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
- float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
- float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
-
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(dy[i]));
- assert(!isnan(y[i]));
- }
-#endif
- // Jii = yi - yi*yi
- // Jij = -yi*yj
- // J = diag(y)-y.T*y
- // dx = J * dy
- // dxk = sum_i(Jki * dyi)
- // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
- // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
- // dxk = sum_i(-yk*yi * dyi) + yk*dyk
- // dxk = -yk * sum_i(yi * dyi) + yk*dyk
- // dxk = -yk * dot(y, dy) + yk*dyk
- // dxk = yk * (- dot(y, dy) + dyk)
- // dxk = yk * (dyk - dot(y, dy))
- //
- // post-order:
- // dot_y_dy := dot(y, dy)
- // dx := dy
- // dx := dx - dot_y_dy
- // dx := dx * y
-
- // linear runtime, no additional memory
- float dot_y_dy = 0;
- ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
- ggml_vec_cpy_f32 (nc, dx, dy);
- ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
- ggml_vec_mul_f32 (nc, dx, dx, y);
-
-#ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dx[i]));
- assert(!isinf(dx[i]));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_soft_max_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_soft_max_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_clamp
-
-static void ggml_compute_forward_clamp_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- float min;
- float max;
- memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- const size_t nb00 = src0->nb[0];
- const size_t nb01 = src0->nb[1];
-
- const size_t nb0 = dst->nb[0];
- const size_t nb1 = dst->nb[1];
-
- GGML_ASSERT( nb0 == sizeof(float));
- GGML_ASSERT(nb00 == sizeof(float));
-
- for (int j = ith; j < n; j += nth) {
- float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
- float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-
- for (int i = 0; i < nc; i++) {
- dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
- }
- }
-}
-
-static void ggml_compute_forward_clamp(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_clamp_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_Q8_K:
- case GGML_TYPE_Q4_0_4_4:
- case GGML_TYPE_Q4_0_4_8:
- case GGML_TYPE_Q4_0_8_8:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_I64:
- case GGML_TYPE_F64:
- case GGML_TYPE_COUNT:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_rope
-
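-// linear ramp over rotation pairs (i0/2): returns 1 below `low`, 0 above `high`,
-// and interpolates linearly in between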
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
- const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
- return 1 - MIN(1, MAX(0, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
- float * cos_theta, float * sin_theta) {
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = freq_scale * theta_extrap;
- float theta = theta_interp;
- if (ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
- }
- *cos_theta = cosf(theta) * mscale;
- *sin_theta = sinf(theta) * mscale;
-}
-
-// Solving `n_ctx_orig = n_rot * 2pi * base^((2 * x) / n_dims)` for x (the dimension index whose
-// rotation completes n_rot full cycles over the original context), we get
-// `corr_dim(n_rot) = n_dims * log(n_ctx_orig / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
- return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
-}
-
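-// precompute the (cos, sin) pair for every even dimension, scaled by mscale;
-// theta advances by a factor of theta_scale per pair and is optionally divided
-// by the per-pair frequency factor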
-static void ggml_rope_cache_init(
- float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale) {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- float theta = theta_base;
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
- rope_yarn(
- theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
- );
- cache[i0 + 1] *= sin_sign;
-
- theta *= theta_scale;
- }
-}
-
-void ggml_rope_yarn_corr_dims(
- int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
-) {
- // start and end correction dims
- float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
- float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
- dims[0] = MAX(0, start);
- dims[1] = MIN(n_dims - 1, end);
-}
-
-static void ggml_compute_forward_rope_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const bool forward) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * src2 = dst->src[2];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
- GGML_ASSERT(nb00 == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(dst);
-
- GGML_ASSERT(n_dims <= ne0);
- GGML_ASSERT(n_dims % 2 == 0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- // running row index; rows outside [ir0, ir1) are skipped, so each row is handled by exactly one thread
- int ir = 0;
-
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
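-
- // theta for pair i is p * theta_scale^i, i.e. the standard RoPE base^(-2i/n_dims) frequencies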
-
- float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
- const float * freq_factors = NULL;
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
-
- // backward process uses inverse rotation by cos and sin.
- // cos and sin build a rotation matrix, where the inverse is the transpose.
- // this essentially just switches the sign of sin.
- const float sin_sign = forward ? 1.0f : -1.0f;
-
- const int32_t * pos = (const int32_t *) src1->data;
-
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- for (int64_t i2 = 0; i2 < ne2; i2++) {
- const int64_t p = pos[i2];
-
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- if (ir++ < ir0) continue;
- if (ir > ir1) break;
-
- if (!is_neox) {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = src[0];
- const float x1 = src[1];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- }
- }
-
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- }
-}
-
-// TODO: deduplicate f16/f32 code
-static void ggml_compute_forward_rope_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const bool forward) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * src2 = dst->src[2];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(dst);
-
- GGML_ASSERT(n_dims <= ne0);
- GGML_ASSERT(n_dims % 2 == 0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- // running row index; rows outside [ir0, ir1) are skipped, so each row is handled by exactly one thread
- int ir = 0;
-
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
- float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
- const float * freq_factors = NULL;
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
-
- // backward process uses inverse rotation by cos and sin.
- // cos and sin build a rotation matrix, where the inverse is the transpose.
- // this essentially just switches the sign of sin.
- const float sin_sign = forward ? 1.0f : -1.0f;
-
- const int32_t * pos = (const int32_t *) src1->data;
-
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- for (int64_t i2 = 0; i2 < ne2; i2++) {
- const int64_t p = pos[i2];
-
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- if (ir++ < ir0) continue;
- if (ir > ir1) break;
-
- if (!is_neox) {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = GGML_FP16_TO_FP32(src[0]);
- const float x1 = GGML_FP16_TO_FP32(src[1]);
-
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
- const int64_t ic = i0/2;
-
- const float cos_theta = cache[i0 + 0];
- const float sin_theta = cache[i0 + 1];
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = GGML_FP16_TO_FP32(src[0]);
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- }
- }
-
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_rope(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_rope_f16(params, dst, true);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_rope_f32(params, dst, true);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_rope_back
-
-static void ggml_compute_forward_rope_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_rope_f16(params, dst, false);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_rope_f32(params, dst, false);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nk = ne00*ne01*ne02;
-
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb10 == sizeof(float));
-
- if (ith == 0) {
- memset(params->wdata, 0, params->wsize);
-
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
- {
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
- ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- dst_data[i00*ne02 + i02] = src[i00];
- }
- }
- }
- }
-
- // permute source data (src1) from (L x Cin) to (Cin x L)
- {
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
- ggml_fp16_t * dst_data = wdata;
-
- for (int64_t i11 = 0; i11 < ne11; i11++) {
- const float * const src = (float *)((char *) src1->data + i11*nb11);
- for (int64_t i10 = 0; i10 < ne10; i10++) {
- dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
- }
- }
- }
-
- // need to zero dst since we are accumulating into it
- memset(dst->data, 0, ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
-
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
- // total rows in dst
- const int nr = ne1;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
- ggml_fp16_t * const wdata_src = wdata + nk;
-
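- // for each output channel row i1: every input position i10 contributes a dot product over
- // the Cin channels to output position i10*s0 + i00, one per kernel tap i00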
- for (int i1 = ir0; i1 < ir1; i1++) {
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
- ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
- for (int i10 = 0; i10 < ne10; i10++) {
- const int i1n = i10*ne11;
- for (int i00 = 0; i00 < ne00; i00++) {
- float v = 0;
- ggml_vec_dot_f16(ne02, &v, 0,
- (ggml_fp16_t *) wdata_src + i1n, 0,
- (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
- dst_data[i10*s0 + i00] += v;
- }
- }
- }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nk = ne00*ne01*ne02;
-
- GGML_ASSERT(nb00 == sizeof(float));
- GGML_ASSERT(nb10 == sizeof(float));
-
- if (ith == 0) {
- memset(params->wdata, 0, params->wsize);
-
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
- {
- float * const wdata = (float *) params->wdata + 0;
-
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
- float * dst_data = wdata + i01*ne00*ne02;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- dst_data[i00*ne02 + i02] = src[i00];
- }
- }
- }
- }
-
- // permute source data (src1) from (L x Cin) to (Cin x L)
- {
- float * const wdata = (float *) params->wdata + nk;
- float * dst_data = wdata;
-
- for (int64_t i11 = 0; i11 < ne11; i11++) {
- const float * const src = (float *)((char *) src1->data + i11*nb11);
- for (int64_t i10 = 0; i10 < ne10; i10++) {
- dst_data[i10*ne11 + i11] = src[i10];
- }
- }
- }
-
- // need to zero dst since we are accumulating into it
- memset(dst->data, 0, ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
-
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
- // total rows in dst
- const int nr = ne1;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * const wdata = (float *) params->wdata + 0;
- float * const wdata_src = wdata + nk;
-
- for (int i1 = ir0; i1 < ir1; i1++) {
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
- float * wdata_kernel = wdata + i1*ne02*ne00;
- for (int i10 = 0; i10 < ne10; i10++) {
- const int i1n = i10*ne11;
- for (int i00 = 0; i00 < ne00; i00++) {
- float v = 0;
- ggml_vec_dot_f32(ne02, &v, 0,
- wdata_src + i1n, 0,
- wdata_kernel + i00*ne02, 0, 1);
- dst_data[i10*s0 + i00] += v;
- }
- }
- }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_conv_transpose_1d_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_im2col_f32
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst: result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
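-
- // op_params: (s0, s1) stride, (p0, p1) padding, (d0, d1) dilation, is_2D flag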
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t N = is_2D ? ne13 : ne12;
- const int64_t IC = is_2D ? ne12 : ne11;
- const int64_t IH = is_2D ? ne11 : 1;
- const int64_t IW = ne10;
-
- const int64_t KH = is_2D ? ne01 : 1;
- const int64_t KW = ne00;
-
- const int64_t OH = is_2D ? ne2 : 1;
- const int64_t OW = ne1;
-
- int ofs0 = is_2D ? nb13 : nb12;
- int ofs1 = is_2D ? nb12 : nb11;
-
- GGML_ASSERT(nb10 == sizeof(float));
-
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
- {
- float * const wdata = (float *) dst->data;
-
- for (int64_t in = 0; in < N; in++) {
- for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
- for (int64_t iow = 0; iow < OW; iow++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
-
- // micro kernel
- float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
- for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
- } else {
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-
-// ggml_compute_forward_im2col_f16
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst: result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t N = is_2D ? ne13 : ne12;
- const int64_t IC = is_2D ? ne12 : ne11;
- const int64_t IH = is_2D ? ne11 : 1;
- const int64_t IW = ne10;
-
- const int64_t KH = is_2D ? ne01 : 1;
- const int64_t KW = ne00;
-
- const int64_t OH = is_2D ? ne2 : 1;
- const int64_t OW = ne1;
-
- int ofs0 = is_2D ? nb13 : nb12;
- int ofs1 = is_2D ? nb12 : nb11;
-
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb10 == sizeof(float));
-
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
- {
- ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
- for (int64_t in = 0; in < N; in++) {
- for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
- for (int64_t iow = 0; iow < OW; iow++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
-
- // micro kernel
- ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
- for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
- } else {
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_im2col(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- switch (dst->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_im2col_f16(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_im2col_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_im2col_back_f32
-
-static void ggml_compute_forward_im2col_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
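-
- // backward of im2col: dst is the gradient w.r.t. the input image [N, IC, IH, IW];
- // src1 holds the incoming gradients laid out as [N, OH, OW, IC*KH*KW]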
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t N = is_2D ? ne3 : ne2;
- const int64_t IC = is_2D ? ne2 : ne1;
- const int64_t IH = is_2D ? ne1 : 1;
- const int64_t IW = ne0;
-
- const int64_t KH = is_2D ? ne01 : 1;
- const int64_t KW = ne00;
-
- const int64_t OH = is_2D ? ne12 : 1;
- const int64_t OW = ne11;
-
- int ofs0 = is_2D ? nb3 : nb2;
- int ofs1 = is_2D ? nb2 : nb1;
-
- GGML_ASSERT(nb0 == sizeof(float));
-
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
- {
- float * const wdata = (float *) dst->data;
-
- for (int64_t in = 0; in < N; in++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
- for (int64_t iih = 0; iih < IH; iih++) {
- for (int64_t iiw = 0; iiw < IW; iiw++) {
-
- // micro kernel
- float grad = 0.0f;
- for (int64_t ikh = 0; ikh < KH; ikh++) {
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- // For s0 > 1 some values were skipped over in the forward pass.
- // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
- const int64_t tmpw = (iiw + p0 - ikw*d0);
- if (tmpw % s0 != 0) {
- continue;
- }
- const int64_t iow = tmpw / s0;
-
- // Equivalent logic as above except for s1.
- int64_t ioh;
- if (is_2D) {
- const int64_t tmph = iih + p1 - ikh*d1;
-
- if (tmph % s1 != 0) {
- continue;
- }
-
- ioh = tmph / s1;
- } else {
- ioh = 0;
- }
-
- if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
- continue;
- }
-
- const float * const src_data = (const float *) src1->data
- + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
- }
- }
- float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
- dst_data[iih*IW + iiw] = grad;
- }
- }
- }
- }
- }
-}
-
-// ggml_compute_forward_conv_transpose_2d
-
-static void ggml_compute_forward_conv_transpose_2d(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nk = ne00*ne01*ne02*ne03;
-
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nb10 == sizeof(float));
-
- if (ith == 0) {
- memset(params->wdata, 0, params->wsize);
-
- // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
- {
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
- ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
- }
- }
- }
- }
- }
-
- // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
- {
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
- for (int i12 = 0; i12 < ne12; i12++) {
- for (int i11 = 0; i11 < ne11; i11++) {
- const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
- ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
- for (int i10 = 0; i10 < ne10; i10++) {
- dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
- }
- }
- }
- }
-
- memset(dst->data, 0, ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
-
- const int32_t stride = ggml_get_op_params_i32(dst, 0);
-
- // total patches in dst
- const int np = ne2;
-
- // patches per thread
- const int dp = (np + nth - 1)/nth;
-
- // patch range for this thread
- const int ip0 = dp*ith;
- const int ip1 = MIN(ip0 + dp, np);
-
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
- ggml_fp16_t * const wdata_src = wdata + nk;
-
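- // for each output channel i2: scatter-accumulate, at every kernel tap (i01, i00), the dot
- // product over the Cin channels into the output, with taps spaced `stride` apart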
- for (int i2 = ip0; i2 < ip1; i2++) { // Cout
- float * dst_data = (float *)((char *) dst->data + i2*nb2);
- ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
- for (int i11 = 0; i11 < ne11; i11++) {
- for (int i10 = 0; i10 < ne10; i10++) {
- const int i1n = i11*ne10*ne12 + i10*ne12;
- for (int i01 = 0; i01 < ne01; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- float v = 0;
- ggml_vec_dot_f16(ne03, &v, 0,
- wdata_src + i1n, 0,
- wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
- dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
- }
- }
- }
- }
- }
-}
-
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
- const struct ggml_compute_params * params,
- const enum ggml_op_pool op,
- const int k,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src = dst->src[0];
-
- assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
- if (params->ith != 0) {
- return;
- }
-
- const char * cdata = (const char *)src->data;
- const char * const data_end = cdata + ggml_nbytes(src);
- float * drow = (float *)dst->data;
-
- const int64_t rs = dst->ne[0];
-
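- // each dst element pools k consecutive src elements (stride == k, no padding),
- // one src/dst row pair at a time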
- while (cdata < data_end) {
- const void * srow = (const void *)cdata;
- int j = 0;
- for (int64_t i = 0; i < rs; ++i) {
- switch (op) {
- case GGML_OP_POOL_AVG: drow[i] = 0; break;
- case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
- for (int ki = 0; ki < k; ++ki) {
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
- switch (op) {
- case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
- case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
- ++j;
- }
- switch (op) {
- case GGML_OP_POOL_AVG: drow[i] /= k; break;
- case GGML_OP_POOL_MAX: break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
- }
-
- cdata += src->nb[1];
- drow += rs;
- }
-}
-
-// ggml_compute_forward_pool_1d
-
-static void ggml_compute_forward_pool_1d(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = opts[0];
- const int k0 = opts[1];
- const int s0 = opts[2];
- const int p0 = opts[3];
- GGML_ASSERT(p0 == 0); // padding not supported
- GGML_ASSERT(k0 == s0); // only s = k supported
-
- ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
-}
-
-// ggml_compute_forward_pool_2d
-
-static void ggml_compute_forward_pool_2d(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src = dst->src[0];
-
- assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
- if (params->ith != 0) {
- return;
- }
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = opts[0];
- const int k0 = opts[1];
- const int k1 = opts[2];
- const int s0 = opts[3];
- const int s1 = opts[4];
- const int p0 = opts[5];
- const int p1 = opts[6];
- const char * cdata = (const char*)src->data;
- const char * const data_end = cdata + ggml_nbytes(src);
-
- const int64_t px = dst->ne[0];
- const int64_t py = dst->ne[1];
- const int64_t pa = px * py;
-
- float * dplane = (float *)dst->data;
-
- const int ka = k0 * k1;
- const int offset0 = -p0;
- const int offset1 = -p1;
-
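- // for every output plane: slide a k0 x k1 window with strides (s0, s1), starting at the
- // (negative) padding offsets; out-of-bounds taps are skipped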
- while (cdata < data_end) {
- for (int oy = 0; oy < py; ++oy) {
- float * const drow = dplane + oy * px;
- for (int ox = 0; ox < px; ++ox) {
- float * const out = drow + ox;
- switch (op) {
- case GGML_OP_POOL_AVG: *out = 0; break;
- case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
-
- const int ix = offset0 + ox * s0;
- const int iy = offset1 + oy * s1;
-
- for (int ky = 0; ky < k1; ++ky) {
- if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
- const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
- for (int kx = 0; kx < k0; ++kx) {
- int j = ix + kx;
- if (j < 0 || j >= src->ne[0]) continue;
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
- switch (op) {
- case GGML_OP_POOL_AVG: *out += srow_j; break;
- case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
- }
- }
- switch (op) {
- case GGML_OP_POOL_AVG: *out /= ka; break;
- case GGML_OP_POOL_MAX: break;
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
- }
- }
- }
-
- cdata += src->nb[2];
- dplane += pa;
- }
-}
-
-// ggml_compute_forward_pool_2d_back
-
-static void ggml_compute_forward_pool_2d_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src = dst->src[0];
- const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
-
- assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
- if (params->ith != 0) {
- return;
- }
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = opts[0];
- const int k0 = opts[1];
- const int k1 = opts[2];
- const int s0 = opts[3];
- const int s1 = opts[4];
- const int p0 = opts[5];
- const int p1 = opts[6];
-
- char * cdata = (char *) dst->data;
- const char * cdataf = (const char *) dstf->data;
- const char * const data_end = cdata + ggml_nbytes(dst);
-
- GGML_ASSERT(params->ith == 0);
- memset(cdata, 0, ggml_nbytes(dst));
-
- const int64_t px = src->ne[0];
- const int64_t py = src->ne[1];
- const int64_t pa = px * py;
-
- const float * splane = (const float *) src->data;
-
- const int ka = k0 * k1;
- const int offset0 = -p0;
- const int offset1 = -p1;
-
- while (cdata < data_end) {
- for (int oy = 0; oy < py; ++oy) {
- const float * const srow = splane + oy * px;
- for (int ox = 0; ox < px; ++ox) {
- const float grad0 = srow[ox];
-
- const int ix = offset0 + ox * s0;
- const int iy = offset1 + oy * s1;
-
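- // MAX: route the incoming gradient to the position that attained the maximum in the
- // forward output (dstf); AVG: spread it uniformly over the kernel window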
- if (op == GGML_OP_POOL_MAX) {
- float maxval = -FLT_MAX;
- int kxmax = -1;
- int kymax = -1;
-
- for (int ky = 0; ky < k1; ++ky) {
- if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
- continue;
- }
- const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
- for (int kx = 0; kx < k0; ++kx) {
- int j = ix + kx;
- if (j < 0 || j >= dst->ne[0]) {
- continue;
- }
-
- const float val = dst->type == GGML_TYPE_F32 ?
- ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
- if (val <= maxval) {
- continue;
- }
-
- maxval = val;
- kxmax = kx;
- kymax = ky;
- }
- }
-
- if (kxmax == -1 || kymax == -1) {
- continue;
- }
-
- void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
- const int j = ix + kxmax;
- if (dst->type == GGML_TYPE_F32) {
- ((float *) drow)[j] += grad0;
- } else {
- ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
- }
- } else if (op == GGML_OP_POOL_AVG) {
- const float grad = grad0 / ka;
-
- for (int ky = 0; ky < k1; ++ky) {
- if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
- continue;
- }
- void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
- for (int kx = 0; kx < k0; ++kx) {
- int j = ix + kx;
- if (j < 0 || j >= dst->ne[0]) {
- continue;
- }
-
- if (dst->type == GGML_TYPE_F32) {
- ((float *) drow)[j] += grad;
- } else {
- // accumulate in fp32 and convert back: ggml_fp16_t may be a plain 16-bit storage type,
- // so adding encoded half values directly would be incorrect (matches the MAX branch above)
- ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
- }
- }
- }
- } else {
- GGML_ASSERT(false);
- }
- }
- }
-
- cdata += dst->nb[2];
- cdataf += dst->nb[2];
- splane += pa;
- }
-}
-
-// ggml_compute_forward_upscale
-
-static void ggml_compute_forward_upscale_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
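-
- // nearest-neighbor scaling: dst index i maps back to source index i / sf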
-
- // TODO: optimize
-
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- const int64_t i03 = i3 / sf3;
- for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
- const int64_t i02 = i2 / sf2;
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- const int64_t i01 = i1 / sf1;
- for (int64_t i0 = 0; i0 < ne0; i0++) {
- const int64_t i00 = i0 / sf0;
-
- const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
- *y = *x;
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_upscale(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_upscale_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-
-// ggml_compute_forward_pad
-
-static void ggml_compute_forward_pad_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
- GGML_ASSERT( dst->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- float * dst_ptr = (float *) dst->data;
-
- // TODO: optimize
-
- for (int64_t i2 = 0; i2 < ne2; ++i2) {
- for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- for (int64_t i3 = 0; i3 < ne3; ++i3) {
- const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
- dst_ptr[dst_idx] = *src_ptr;
- } else {
- dst_ptr[dst_idx] = 0;
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_pad(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_pad_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-
-// ggml_compute_forward_arange
-
-static void ggml_compute_forward_arange_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- GGML_ASSERT(dst->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const float start = ggml_get_op_params_f32(dst, 0);
- const float stop = ggml_get_op_params_f32(dst, 1);
- const float step = ggml_get_op_params_f32(dst, 2);
-
- const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
- GGML_ASSERT(ggml_nelements(dst) == steps);
-
- for (int64_t i = ith; i < steps; i+= nth) {
- float value = start + step * i;
- ((float *)dst->data)[i] = value;
- }
-}
-
-static void ggml_compute_forward_arange(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- switch (dst->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_arange_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-static void ggml_compute_forward_timestep_embedding_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(src0->nb[0] == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int dim = ggml_get_op_params_i32(dst, 0);
- const int max_period = ggml_get_op_params_i32(dst, 1);
-
- int half = dim / 2;
-
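- // standard sinusoidal timestep embedding:
- //   freq_j = exp(-ln(max_period) * j / half) = max_period^(-j/half)
- //   embed[j] = cos(t * freq_j), embed[j + half] = sin(t * freq_j)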
- for (int64_t i = 0; i < ne00; i++) {
- float * embed_data = (float *)((char *) dst->data + i*nb1);
- for (int64_t j = ith; j < half; j += nth) {
- float timestep = ((float *)src0->data)[i];
- float freq = (float)expf(-logf(max_period) * j / half);
- float arg = timestep * freq;
- embed_data[j] = cosf(arg);
- embed_data[j + half] = sinf(arg);
- }
- if (dim % 2 != 0 && ith == 0) {
- embed_data[dim] = 0.f;
- }
- }
-}
-
-static void ggml_compute_forward_timestep_embedding(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_timestep_embedding_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_argsort
-
-static void ggml_compute_forward_argsort_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- GGML_ASSERT(nb0 == sizeof(float));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t nr = ggml_nrows(src0);
-
- enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
-
- for (int64_t i = ith; i < nr; i += nth) {
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
- const float * src_data = (float *)((char *) src0->data + i*nb01);
-
- for (int64_t j = 0; j < ne0; j++) {
- dst_data[j] = j;
- }
-
- // the standard qsort comparator cannot capture src_data, so use a simple bubble sort instead
- for (int64_t j = 0; j < ne0; j++) {
- for (int64_t k = j + 1; k < ne0; k++) {
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
- int32_t tmp = dst_data[j];
- dst_data[j] = dst_data[k];
- dst_data[k] = tmp;
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_argsort(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_argsort_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_flash_attn_ext
-
-static void ggml_compute_forward_flash_attn_ext_f16(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * q,
- const struct ggml_tensor * k,
- const struct ggml_tensor * v,
- const struct ggml_tensor * mask,
- struct ggml_tensor * dst) {
-
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t D = neq0;
- const int64_t N = neq1;
-
- GGML_ASSERT(ne0 == D);
- GGML_ASSERT(ne2 == N);
-
- // input tensor rows must be contiguous
- GGML_ASSERT(nbq0 == ggml_type_size(q->type));
- GGML_ASSERT(nbk0 == ggml_type_size(k->type));
- GGML_ASSERT(nbv0 == ggml_type_size(v->type));
-
- GGML_ASSERT(neq0 == D);
- GGML_ASSERT(nek0 == D);
- GGML_ASSERT(nev0 == D);
-
- GGML_ASSERT(neq1 == N);
- GGML_ASSERT(nev0 == D);
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- // broadcast factors
- const int64_t rk2 = neq2/nek2;
- const int64_t rk3 = neq3/nek3;
-
- const int64_t rv2 = neq2/nev2;
- const int64_t rv3 = neq3/nev3;
-
- // parallelize by q rows
-
- // total rows in q
- const int nr = neq1*neq2*neq3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float scale = 1.0f;
- float max_bias = 0.0f;
- float logit_softcap = 0.0f;
-
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
- memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
-
- if (logit_softcap != 0) {
- scale /= logit_softcap;
- }
-
- const uint32_t n_head = neq2;
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
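- // ALiBi slopes (see the per-head slope below):
- //   h <  n_head_log2 : slope = m0^(h+1)                   = 2^(-(h+1) * max_bias / n_head_log2)
- //   h >= n_head_log2 : slope = m1^(2*(h - n_head_log2)+1) = 2^(-(2*(h - n_head_log2)+1) * max_bias / (2*n_head_log2))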
-
- enum ggml_type const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
- ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits(k_vec_dot_type)->from_float;
- ggml_vec_dot_t const kq_vec_dot = type_traits_cpu[k->type].vec_dot;
- ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
-
- GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
- GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
-
- // loop over n_batch and n_head
- for (int ir = ir0; ir < ir1; ++ir) {
- // q indices
- const int iq3 = ir/(neq2*neq1);
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
- const uint32_t h = iq2; // head index
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float S = 0.0f; // sum
- float M = -INFINITY; // maximum KQ value
-
- float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
- float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
- ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
- ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
-
- if (v->type == GGML_TYPE_F16) {
- memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
- } else {
- memset(VKQ32, 0, D*sizeof(float));
- }
-
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
-
- // k indices
- const int ik3 = iq3 / rk3;
- const int ik2 = iq2 / rk2;
-
- // v indices
- const int iv3 = iq3 / rv3;
- const int iv2 = iq2 / rv2;
-
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
- q_to_vec_dot(pq, Q_q, D);
-
- // online softmax / attention
- // loop over n_kv and n_head_kv
- // ref: https://arxiv.org/pdf/2112.05682.pdf
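- //
- // online softmax recurrence used below (M is the running max, S the running sum):
- //   M_new = max(M, s); ms = expf(M - M_new); vs = expf(s - M_new)
- //   VKQ  <- VKQ*ms + v*vs
- //   S    <- S*ms + vs
- // so dst = VKQ/S at the end, without materializing the full KQ row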
- for (int64_t ic = 0; ic < nek1; ++ic) {
- const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
- if (mv == -INFINITY) {
- continue;
- }
-
- float s; // KQ value
-
- const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
- kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
- s = s*scale; // scale KQ value
-
- if (logit_softcap != 0.0f) {
- s = logit_softcap*tanhf(s);
- }
-
- s += mv; // apply mask
-
- const float Mold = M;
-
- float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
- float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
- const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
- if (v->type == GGML_TYPE_F16) {
- if (s > M) {
- // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
- M = s;
- ms = expf(Mold - M);
-
- // V = V*expf(Mold - M)
- ggml_vec_scale_f16(D, VKQ16, ms);
- } else {
- // no new maximum, ms == 1.0f, vs != 1.0f
- vs = expf(s - M);
- }
-
- // V += v*expf(s - M)
- ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
- } else {
- if (s > M) {
- // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
- M = s;
- ms = expf(Mold - M);
-
- // V = V*expf(Mold - M)
- ggml_vec_scale_f32(D, VKQ32, ms);
- } else {
- // no new maximum, ms == 1.0f, vs != 1.0f
- vs = expf(s - M);
- }
-
- v_to_float(v_data, V32, D);
-
- // V += v*expf(s - M)
- ggml_vec_mad_f32(D, VKQ32, V32, vs);
- }
-
- S = S*ms + vs; // scale and increment sum with partial sum
- }
-
- if (v->type == GGML_TYPE_F16) {
- for (int64_t d = 0; d < D; ++d) {
- VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
- }
- }
-
- // V /= S
- const float S_inv = 1.0f/S;
- ggml_vec_scale_f32(D, VKQ32, S_inv);
-
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
-
- // original
- //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
- // permute(0, 2, 1, 3)
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
- }
-}
-
-static void ggml_compute_forward_flash_attn_ext(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * q,
- const struct ggml_tensor * k,
- const struct ggml_tensor * v,
- const struct ggml_tensor * mask,
- struct ggml_tensor * dst) {
- switch (dst->op_params[3]) {
- case GGML_PREC_DEFAULT:
- case GGML_PREC_F32:
- {
- // uses F32 accumulators
- ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_flash_attn_back
-
-static void ggml_compute_forward_flash_attn_back_f32(
- const struct ggml_compute_params * params,
- const bool masked,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * q = dst->src[0];
- const struct ggml_tensor * k = dst->src[1];
- const struct ggml_tensor * v = dst->src[2];
- const struct ggml_tensor * d = dst->src[3];
-
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
- GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
- GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t D = neq0;
- const int64_t N = neq1;
- const int64_t P = nek1 - N;
- const int64_t M = P + N;
-
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
- const int mxDM = MAX(D, Mup);
-
- // GGML_ASSERT(ne0 == D);
- // GGML_ASSERT(ne1 == N);
- GGML_ASSERT(P >= 0);
-
- GGML_ASSERT(nbq0 == sizeof(float));
- GGML_ASSERT(nbk0 == sizeof(float));
- GGML_ASSERT(nbv0 == sizeof(float));
-
- GGML_ASSERT(neq0 == D);
- GGML_ASSERT(nek0 == D);
- GGML_ASSERT(nev1 == D);
- GGML_ASSERT(ned0 == D);
-
- GGML_ASSERT(neq1 == N);
- GGML_ASSERT(nek1 == N + P);
- GGML_ASSERT(nev1 == D);
- GGML_ASSERT(ned1 == N);
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- if (ith == 0) {
- memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
- }
- ggml_barrier(params->threadpool);
-
- const int64_t elem_q = ggml_nelements(q);
- const int64_t elem_k = ggml_nelements(k);
-
- enum ggml_type result_type = dst->type;
- GGML_ASSERT(ggml_blck_size(result_type) == 1);
- const size_t tsize = ggml_type_size(result_type);
-
- const size_t offs_q = 0;
- const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
- const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
-
- void * grad_q = (char *) dst->data;
- void * grad_k = (char *) dst->data + offs_k;
- void * grad_v = (char *) dst->data + offs_v;
-
- const size_t nbgq1 = nb0*neq0;
- const size_t nbgq2 = nb0*neq0*neq1;
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
- const size_t nbgk1 = nb0*nek0;
- const size_t nbgk2 = nb0*nek0*nek1;
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
- const size_t nbgv1 = nb0*nev0;
- const size_t nbgv2 = nb0*nev0*nev1;
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
- // parallelize by k rows using ggml_vec_dot_f32
-
- // total rows in k
- const int nr = nek2*nek3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- const float scale = 1.0f/sqrtf(D);
-
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
- // how often k2 (and v2) is repeated in q2
- int nrep = neq2/nek2;
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // q indices
- const int ik3 = ir/(nek2);
- const int ik2 = ir - ik3*nek2;
-
- const int iq3 = ik3;
- const int id3 = ik3;
- const int iv3 = ik3;
- const int iv2 = ik2;
-
- for (int irep = 0; irep < nrep; ++irep) {
- const int iq2 = ik2 + irep*nek2;
- const int id2 = iq2;
-
- // (ik2 + irep*nek2) % nek2 == ik2
- for (int iq1 = 0; iq1 < neq1; ++iq1) {
- const int id1 = iq1;
-
- // NOTE: the CACHE_LINE_SIZE_F32 padding here is unverified:
- // possibly it should not be multiplied by 2 and should be excluded from the 1*(..) offset of SM
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
-
- for (int i = M; i < Mup; ++i) {
- S[i] = -INFINITY;
- }
-
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
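- // with causal masking, query row iq1 only attends to the first P + iq1 + 1 key rows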
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
- // k indices
- const int ik1 = ic;
-
- // S indices
- const int i1 = ik1;
-
- ggml_vec_dot_f32(neq0,
- S + i1, 0,
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
- }
-
- // scale
- ggml_vec_scale_f32(masked_begin, S, scale);
-
- for (int64_t i = masked_begin; i < M; i++) {
- S[i] = -INFINITY;
- }
-
- // softmax
- // exclude known -INF S[..] values from max and loop
- // don't forget to set their SM values to zero
- {
- float max = -INFINITY;
- ggml_vec_max_f32(masked_begin, &max, S);
-
- ggml_float sum = 0.0;
- {
-#ifdef GGML_SOFT_MAX_ACCELERATE
- max = -max;
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
- vvexpf(SM, SM, &Mup);
- ggml_vec_sum_f32(Mup, &sum, SM);
-#else
- sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
-#endif
- }
-
- assert(sum > 0.0);
-
- sum = 1.0/sum;
- ggml_vec_scale_f32(masked_begin, SM, sum);
-
- }
-
- // step-by-step explanation
- {
- // forward-process shape grads from backward process
- // parallel_for ik2,ik3:
- // for irep:
- // iq2 = ik2 + irep*nek2
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
- // for iq1:
- // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
- // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
- // S0 = -Inf [D,1,1,1]
- // ~S1[i] = dot(kcur[:D,i], qcur)
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
- // ~S5[i] = dot(vcur[:,i], S4)
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
- // dst backward-/ grad[dst] = d
- //
- // output gradients with their dependencies:
- //
- // grad[kcur] = grad[S1].T @ qcur
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S4] = grad[S5] @ vcur
- // grad[S4] = d[:D,id1,id2,id3] @ vcur
- // grad[qcur] = grad[S1] @ kcur
- // grad[vcur] = grad[S5].T @ S4
- // grad[vcur] = d[:D,id1,id2,id3].T @ S4
- //
- // in post-order:
- //
- // S1 = qcur @ kcur.T
- // S2 = S1 * scale
- // S3 = diag_mask_inf(S2, P)
- // S4 = softmax(S3)
- // grad[S4] = d[:D,id1,id2,id3] @ vcur
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[qcur] = grad[S1] @ kcur
- // grad[kcur] = grad[S1].T @ qcur
- // grad[vcur] = d[:D,id1,id2,id3].T @ S4
- //
- // using less variables (SM=S4):
- //
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
- // SM = softmax(S)
- // S = d[:D,iq1,iq2,iq3] @ vcur
- // dot_SM_gradSM = dot(SM, S)
- // S = SM * (S - dot(SM, S))
- // S = diag_mask_zero(S, P) * scale
- //
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
- // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
- }
-
- // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
- // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
- // for ic:
- // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
- // exclude known future zero S[..] values from operation
- ggml_vec_set_f32(masked_begin, S, 0);
- for (int64_t ic = 0; ic < D; ++ic) {
- ggml_vec_mad_f32(masked_begin,
- S,
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
- *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
- }
-
- // S = SM * (S - dot(SM, S))
- float dot_SM_gradSM = 0;
- ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
- ggml_vec_mul_f32 (masked_begin, S, S, SM);
-
- // S = diag_mask_zero(S, P) * scale
- // already done by above ggml_vec_set_f32
-
- // exclude known zero S[..] values from operation
- ggml_vec_scale_f32(masked_begin, S, scale);
-
- // S shape [M,1]
- // SM shape [M,1]
- // kcur shape [D,M]
- // qcur shape [D,1]
- // vcur shape [M,D]
-
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
- // for ic:
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
- // exclude known zero S[..] values from loop
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
- (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
- S[ic]);
- }
-
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
- // for ic:
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
- // exclude known zero S[..] values from loop
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
- S[ic]);
- }
-
- // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
- // for ic:
- // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
- // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
- // exclude known zero SM[..] values from mad
- for (int64_t ic = 0; ic < D; ++ic) {
- ggml_vec_mad_f32(masked_begin,
- (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
- SM,
- *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_flash_attn_back(
- const struct ggml_compute_params * params,
- const bool masked,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * q = dst->src[0];
-
- switch (q->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_ssm_conv
-
-static void ggml_compute_forward_ssm_conv_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0]; // conv_x
- const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nc = src1->ne[0]; // d_conv
- const int ncs = src0->ne[0]; // d_conv - 1 + n_t
- const int nr = src0->ne[1]; // d_inner
- const int n_t = dst->ne[1]; // tokens per sequence
- const int n_s = dst->ne[2]; // number of sequences in the batch
-
- GGML_ASSERT( dst->ne[0] == nr);
- GGML_ASSERT(src0->nb[0] == sizeof(float));
- GGML_ASSERT(src1->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
-
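- // for each row i1 in [ir0, ir1) and token i2:
- //   dst[i1, i2, i3] = sum_{i0 < d_conv} conv_x[i2 + i0, i1, i3] * conv1d_weight[i0, i1]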
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- // {d_conv - 1 + n_t, d_inner, n_seqs}
- // sliding window
- const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
- const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
- float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
-
- // TODO: transpose the output for smaller strides for big batches?
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // rowwise dot product
- // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
- float sumf = 0.0f;
-
- // d_conv
- for (int i0 = 0; i0 < nc; ++i0) {
- sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
- }
- x[i1] = sumf;
- }
- }
- }
-}
-
-static void ggml_compute_forward_ssm_conv(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- switch (dst->src[0]->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_ssm_conv_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_ssm_scan
-
-static void ggml_compute_forward_ssm_scan_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0]; // s
- const struct ggml_tensor * src1 = dst->src[1]; // x
- const struct ggml_tensor * src2 = dst->src[2]; // dt
- const struct ggml_tensor * src3 = dst->src[3]; // A
- const struct ggml_tensor * src4 = dst->src[4]; // B
- const struct ggml_tensor * src5 = dst->src[5]; // C
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
-
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
- GGML_ASSERT(src0->nb[0] == sizeof(float));
- GGML_ASSERT(src1->nb[0] == sizeof(float));
- GGML_ASSERT(src2->nb[0] == sizeof(float));
- GGML_ASSERT(src3->nb[0] == sizeof(float));
- GGML_ASSERT(src4->nb[0] == sizeof(float));
- GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
-
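- // selective state update, per inner channel i1 and state dim i0:
- //   dt' = softplus(dt[i1])
- //   s[i0,i1] = s0[i0,i1] * exp(dt' * A[i0,i1]) + B[i0] * x[i1] * dt'
- //   y[i1] = sum_i0 s[i0,i1] * C[i0]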
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
-
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
- }
- y[i1] = sumf;
- }
- }
- }
-}
-
-static void ggml_compute_forward_ssm_scan(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- switch (dst->src[0]->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_ssm_scan_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_win_part
-
-static void ggml_compute_forward_win_part_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- UNUSED(params);
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-
- const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t w = ((const int32_t *)(dst->op_params))[2];
-
- assert(ne00 == ne0);
- assert(ne3 == nep0*nep1);
-
- // TODO: optimize / multi-thread
- for (int py = 0; py < nep1; ++py) {
- for (int px = 0; px < nep0; ++px) {
- const int64_t i3 = py*nep0 + px;
- for (int64_t i2 = 0; i2 < ne2; ++i2) {
- for (int64_t i1 = 0; i1 < ne1; ++i1) {
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- const int64_t i02 = py*w + i2;
- const int64_t i01 = px*w + i1;
- const int64_t i00 = i0;
-
- const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
- const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
-
- if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
- ((float *) dst->data)[i] = 0.0f;
- } else {
- ((float *) dst->data)[i] = ((float *) src0->data)[j];
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_win_part(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_win_part_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_win_unpart
-
-static void ggml_compute_forward_win_unpart_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- UNUSED(params);
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-
- const int32_t w = ((const int32_t *)(dst->op_params))[0];
-
- // padding
- const int px = (w - ne1%w)%w;
- //const int py = (w - ne2%w)%w;
-
- const int npx = (px + ne1)/w;
- //const int npy = (py + ne2)/w;
-
- assert(ne0 == ne00);
-
- // TODO: optimize / multi-thread
- for (int64_t i2 = 0; i2 < ne2; ++i2) {
- for (int64_t i1 = 0; i1 < ne1; ++i1) {
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- const int ip2 = i2/w;
- const int ip1 = i1/w;
-
- const int64_t i02 = i2%w;
- const int64_t i01 = i1%w;
- const int64_t i00 = i0;
-
- const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
- const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
-
- ((float *) dst->data)[j] = ((float *) src0->data)[i];
- }
- }
- }
-}
-
-static void ggml_compute_forward_win_unpart(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_win_unpart_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_unary
-
-static void ggml_compute_forward_unary(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const enum ggml_unary_op op = ggml_get_unary_op(dst);
-
- switch (op) {
- case GGML_UNARY_OP_ABS:
- {
- ggml_compute_forward_abs(params, dst);
- } break;
- case GGML_UNARY_OP_SGN:
- {
- ggml_compute_forward_sgn(params, dst);
- } break;
- case GGML_UNARY_OP_NEG:
- {
- ggml_compute_forward_neg(params, dst);
- } break;
- case GGML_UNARY_OP_STEP:
- {
- ggml_compute_forward_step(params, dst);
- } break;
- case GGML_UNARY_OP_TANH:
- {
- ggml_compute_forward_tanh(params, dst);
- } break;
- case GGML_UNARY_OP_ELU:
- {
- ggml_compute_forward_elu(params, dst);
- } break;
- case GGML_UNARY_OP_RELU:
- {
- ggml_compute_forward_relu(params, dst);
- } break;
- case GGML_UNARY_OP_SIGMOID:
- {
- ggml_compute_forward_sigmoid(params, dst);
- } break;
- case GGML_UNARY_OP_GELU:
- {
- ggml_compute_forward_gelu(params, dst);
- } break;
- case GGML_UNARY_OP_GELU_QUICK:
- {
- ggml_compute_forward_gelu_quick(params, dst);
- } break;
- case GGML_UNARY_OP_SILU:
- {
- ggml_compute_forward_silu(params, dst);
- } break;
- case GGML_UNARY_OP_HARDSWISH:
- {
- ggml_compute_forward_hardswish(params, dst);
- } break;
- case GGML_UNARY_OP_HARDSIGMOID:
- {
- ggml_compute_forward_hardsigmoid(params, dst);
- } break;
- case GGML_UNARY_OP_EXP:
- {
- ggml_compute_forward_exp(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_get_rel_pos
-
-static void ggml_compute_forward_get_rel_pos_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- UNUSED(params);
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int64_t w = ne1;
-
- ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
- ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
-
- for (int64_t i2 = 0; i2 < ne2; ++i2) {
- for (int64_t i1 = 0; i1 < ne1; ++i1) {
- const int64_t pos = (w - i1 - 1) + i2;
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
- dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
- }
- }
- }
-}
-
-static void ggml_compute_forward_get_rel_pos(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- {
- ggml_compute_forward_get_rel_pos_f16(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_add_rel_pos
-
-static void ggml_compute_forward_add_rel_pos_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * src2 = dst->src[2];
-
- const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
- if (!inplace) {
- if (params->ith == 0) {
- memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
- }
- ggml_barrier(params->threadpool);
- }
- // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
-
- float * src1_data = (float *) src1->data;
- float * src2_data = (float *) src2->data;
- float * dst_data = (float *) dst->data;
-
- const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
- const int64_t ne12 = src1->ne[2];
- const int64_t ne13 = src1->ne[3];
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- // total patches in dst
- const int np = ne13;
-
- // patches per thread
- const int dp = (np + nth - 1)/nth;
-
- // patch range for this thread
- const int ip0 = dp*ith;
- const int ip1 = MIN(ip0 + dp, np);
-
- for (int64_t i13 = ip0; i13 < ip1; ++i13) {
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
- for (int64_t i10 = 0; i10 < ne10; ++i10) {
- const int64_t jp0 = jp1 + i10;
- const float src1_e = src1_data[jp0];
- const float src2_e = src2_data[jp0];
-
- const int64_t jdh = jp0 * ne10;
- const int64_t jdw = jdh - (ne10 - 1) * i10;
-
- for (int64_t j = 0; j < ne10; ++j) {
- dst_data[jdh + j ] += src2_e;
- dst_data[jdw + j*ne10] += src1_e;
- }
- }
- }
- }
- }
-}
-
-static void ggml_compute_forward_add_rel_pos(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_add_rel_pos_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_rwkv_wkv6
-
-static void ggml_compute_forward_rwkv_wkv6_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
- const int64_t T = dst->src[1]->ne[3];
- const int64_t C = dst->ne[0];
- const int64_t HEADS = dst->src[1]->ne[2];
- const int64_t n_seqs = dst->src[5]->ne[1];
- const int64_t head_size = C / HEADS;
-
- float * dst_data = (float *) dst->data;
- float * state = ((float *) dst->data) + C * T;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- if (ith >= HEADS) {
- return;
- }
-
- const int h_start = (HEADS * ith) / nth;
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
- (HEADS * (ith + 1)) / nth : HEADS;
-
- float * k = (float *) dst->src[0]->data;
- float * v = (float *) dst->src[1]->data;
- float * r = (float *) dst->src[2]->data;
- float * time_faaaa = (float *) dst->src[3]->data;
- float * time_decay = (float *) dst->src[4]->data;
-
- size_t t_stride = HEADS * head_size; // same as C
-
- size_t h_stride = C / HEADS;
- GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
- size_t h_stride_2d = head_size * head_size;
-
- if (ith == 0) {
- memset(dst_data, 0, T * C * sizeof(float));
- }
- ggml_barrier(params->threadpool);
-
-
- #if defined(__AVX__) && !defined(__AVX512F__)
- #define GGML_F32X GGML_F32x8
- #define GGML_F32X_SET1 GGML_F32x8_SET1
- #define GGML_F32X_LOAD GGML_F32x8_LOAD
- #define GGML_F32X_STORE GGML_F32x8_STORE
- #define GGML_F32X_MUL GGML_F32x8_MUL
- #define GGML_F32X_FMA GGML_F32x8_FMA
- #define WKV_VECTOR_SIZE 8
- #elif defined(__AVX512F__)
- #define GGML_F32X GGML_F32x16
- #define GGML_F32X_SET1 GGML_F32x16_SET1
- #define GGML_F32X_LOAD GGML_F32x16_LOAD
- #define GGML_F32X_STORE GGML_F32x16_STORE
- #define GGML_F32X_MUL GGML_F32x16_MUL
- #define GGML_F32X_FMA GGML_F32x16_FMA
- #define WKV_VECTOR_SIZE 16
- #elif defined(__ARM_NEON) && defined(__aarch64__)
- #define GGML_F32X GGML_F32x4
- #define GGML_F32X_SET1 GGML_F32x4_SET1
- #define GGML_F32X_LOAD GGML_F32x4_LOAD
- #define GGML_F32X_STORE GGML_F32x4_STORE
- #define GGML_F32X_MUL GGML_F32x4_MUL
- #define GGML_F32X_FMA GGML_F32x4_FMA
- #define WKV_VECTOR_SIZE 4
- #endif
-
- #ifdef WKV_VECTOR_SIZE
- const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
-
- for (int64_t t = 0; t < T; t++) {
- size_t t_offset = t * t_stride;
- size_t state_offset = head_size * C * (t / (T / n_seqs));
- float * state_cur = state + state_offset;
- float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
- for (int64_t h = h_start; h < h_end; h++) {
- size_t h_offset = h * h_stride;
- size_t t_h_offset = t_offset + h_offset;
- size_t h_2d_offset = h * h_stride_2d;
-
- for (int64_t i = 0; i < head_size; i++) {
- size_t t_h_i_offset = t_h_offset + i;
- size_t h_i_offset = h_offset + i;
- size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
- float k_val = k[t_h_i_offset];
- float r_val = r[t_h_i_offset];
- float time_faaaa_val = time_faaaa[h_i_offset];
- float time_decay_val = time_decay[t_h_i_offset];
-
- // Broadcast scalar values to vectors
- GGML_F32X k_vec = GGML_F32X_SET1(k_val);
- GGML_F32X r_vec = GGML_F32X_SET1(r_val);
- GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
- GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
-
- for (int64_t j = 0; j < vec_count; j++) {
- size_t base_j = j * WKV_VECTOR_SIZE;
- size_t t_h_j_offset = t_h_offset + base_j;
- size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
- // Load WKV_VECTOR_SIZE elements at once
- GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
- GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
- GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
-
- // Compute kv = v * k
- GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
-
- // Compute temp = kv * time_faaaa + prev_state
- GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
-
- // Update dst: dst += temp * r
- dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
- GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
-
- // Update state: state = prev_state * time_decay + kv
- GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
- GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
- }
-
- // Handle any remaining elements (only needed when head_size is not a multiple of WKV_VECTOR_SIZE)
- for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
- size_t t_h_j_offset = t_h_offset + j;
- size_t h_2d_i_j_offset = h_2d_i_offset + j;
- float v_val = v[t_h_j_offset];
- float kv_val = v_val * k_val;
- float prev_state_val = state_prev[h_2d_i_j_offset];
- float temp_val = kv_val * time_faaaa_val + prev_state_val;
- dst_data[t_h_j_offset] += temp_val * r_val;
- state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
- }
- }
- }
- }
-
- #else
- // basically fused operations:
- // dst = r @ (time_faaaa * (k @ v) + state),
- // state = time_decay * state + (k @ v),
- // applied recurrently, token by token
- for (int64_t t = 0; t < T; t++) {
- size_t t_offset = t * t_stride;
- size_t state_offset = head_size * C * (t / (T / n_seqs));
- float * state_cur = state + state_offset;
- float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
- for (int64_t h = h_start; h < h_end; h++) {
- size_t h_offset = h * h_stride;
- size_t t_h_offset = t_offset + h_offset;
- size_t h_2d_offset = h * h_stride_2d;
-
- for (int64_t i = 0; i < head_size; i++) {
- size_t t_h_i_offset = t_h_offset + i;
- size_t h_i_offset = h_offset + i;
- size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
- float k_val = k[t_h_i_offset];
- float r_val = r[t_h_i_offset];
- float time_faaaa_val = time_faaaa[h_i_offset];
- // RWKV v6: different time_decay for each token.
- float time_decay_val = time_decay[t_h_i_offset];
-
- for (int64_t j = 0; j < head_size; j++) {
- size_t t_h_j_offset = t_h_offset + j;
- size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
- float v_val = v[t_h_j_offset];
- float kv_val = v_val * k_val;
- float prev_state_val = state_prev[h_2d_i_j_offset];
- float temp_val = kv_val * time_faaaa_val + prev_state_val;
- dst_data[t_h_j_offset] += temp_val * r_val;
- state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
- }
- }
- }
- }
- #endif
-}
-
-
-static void ggml_compute_forward_rwkv_wkv6(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_rwkv_wkv6_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_unary_op_f32_t fun) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- fun(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
- }
-}
-
-static void ggml_compute_forward_map_unary(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_unary_op_f32_t fun) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_map_unary_f32(params, dst, fun);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_binary_op_f32_t fun) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(src1));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
-
- for (int i = 0; i < n; i++) {
- fun(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])),
- (float *) ((char *) src1->data + i*(src1->nb[1])));
- }
-}
-
-static void ggml_compute_forward_map_binary(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_binary_op_f32_t fun) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_map_binary_f32(params, dst, fun);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_custom1_op_f32_t fun) {
-
- const struct ggml_tensor * a = dst->src[0];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_custom2_op_f32_t fun) {
-
- const struct ggml_tensor * a = dst->src[0];
- const struct ggml_tensor * b = dst->src[1];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst,
- const ggml_custom3_op_f32_t fun) {
-
- const struct ggml_tensor * a = dst->src[0];
- const struct ggml_tensor * b = dst->src[1];
- const struct ggml_tensor * c = dst->src[2];
-
- if (params->ith != 0) {
- return;
- }
-
- fun(dst, a, b, c);
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * a = dst->src[0];
-
- struct ggml_map_custom1_op_params p;
- memcpy(&p, dst->op_params, sizeof(p));
-
- p.fun(dst, a, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * a = dst->src[0];
- const struct ggml_tensor * b = dst->src[1];
-
- struct ggml_map_custom2_op_params p;
- memcpy(&p, dst->op_params, sizeof(p));
-
- p.fun(dst, a, b, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * a = dst->src[0];
- const struct ggml_tensor * b = dst->src[1];
- const struct ggml_tensor * c = dst->src[2];
-
- struct ggml_map_custom3_op_params p;
- memcpy(&p, dst->op_params, sizeof(p));
-
- p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_cross_entropy_loss
-
-static void ggml_compute_forward_cross_entropy_loss_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
- GGML_ASSERT(ggml_are_same_shape(src0, src1));
- GGML_ASSERT(ggml_is_scalar(dst));
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- // TODO: handle transposed/permuted matrices
- const int64_t nc = src0->ne[0];
- const int64_t nr = ggml_nrows(src0);
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- float * sums = (float *) params->wdata;
- float * st = ((float *) params->wdata) + nth + ith*nc;
- float sum_thread = 0.0f;
-
- GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
-
- // rows per thread
- const int64_t dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int64_t ir0 = dr*ith;
- const int64_t ir1 = MIN(ir0 + dr, nr);
-
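- // per row: st = log_softmax(s0) = (s0 - max) - log(sum exp(s0 - max)),
- // the row contribution is sum_i s1[i]*st[i], and the final loss is -sum/nr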
- for (int64_t i1 = ir0; i1 < ir1; ++i1) {
- const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
- const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
-
-#ifndef NDEBUG
- for (int64_t i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(s0[i]));
- assert(!isnan(s1[i]));
- }
-#endif
-
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, s0);
- const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
- assert(sum_softmax >= 0.0);
-
- ggml_vec_add1_f32(nc, st, st, -sum_softmax);
- ggml_vec_mul_f32(nc, st, st, s1);
-
- float sum_st = 0.0f;
- ggml_vec_sum_f32(nc, &sum_st, st);
- sum_thread += sum_st;
-
-#ifndef NDEBUG
- for (int64_t i = 0; i < nc; ++i) {
- assert(!isnan(st[i]));
- assert(!isinf(st[i]));
- }
-#endif
- }
- sums[ith] = sum_thread;
- ggml_barrier(params->threadpool);
-
- if (ith == 0) {
- float * dp = (float *) dst->data;
- ggml_vec_sum_f32(nth, dp, sums);
- dp[0] *= -1.0f / (float) nr;
- }
-}
-
-static void ggml_compute_forward_cross_entropy_loss(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_cross_entropy_loss_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// ggml_compute_forward_cross_entropy_loss_back
-
-static void ggml_compute_forward_cross_entropy_loss_back_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * opt0 = dst->src[2];
-
- GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(opt0));
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
- const int64_t ith = params->ith;
- const int64_t nth = params->nth;
-
- // TODO: handle transposed/permuted matrices
- const int64_t nc = src0->ne[0];
- const int64_t nr = ggml_nrows(src0);
-
- // rows per thread
- const int64_t dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int64_t ir0 = dr*ith;
- const int64_t ir1 = MIN(ir0 + dr, nr);
-
- const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
-
- for (int64_t i1 = ir0; i1 < ir1; i1++) {
- float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-
-#ifndef NDEBUG
- for (int64_t i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(s0[i]));
- assert(!isnan(s1[i]));
- }
-#endif
-
- // soft_max
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, s0);
- ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
- assert(sum > 0.0);
- ggml_vec_scale_f32(nc, ds0, 1.0/sum);
-
- // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
- ggml_vec_sub_f32(nc, ds0, ds0, s1);
- ggml_vec_scale_f32(nc, ds0, d_by_nr);
-
-#ifndef NDEBUG
- for (int64_t i = 0; i < nc; ++i) {
- assert(!isnan(ds0[i]));
- assert(!isinf(ds0[i]));
- }
-#endif
- }
-}
-
-static void ggml_compute_forward_cross_entropy_loss_back(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-static void ggml_compute_forward_opt_step_adamw_f32(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src0_grad = dst->src[1];
- const struct ggml_tensor * src0_grad_m = dst->src[2];
- const struct ggml_tensor * src0_grad_v = dst->src[3];
- GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int nr = ggml_nrows(src0);
-
- GGML_TENSOR_UNARY_OP_LOCALS
- GGML_ASSERT(nb00 == sizeof(float));
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- /* const float gnorm = 1.0f; */
- int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
- const float alpha = ggml_get_op_params_f32(dst, 2);
- const float beta1 = ggml_get_op_params_f32(dst, 3);
- const float beta2 = ggml_get_op_params_f32(dst, 4);
- const float eps = ggml_get_op_params_f32(dst, 5);
- const float wd = ggml_get_op_params_f32(dst, 6);
-
- const float beta1h = alpha/(1.0f - powf(beta1, iter));
- const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
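- // AdamW update with the bias correction folded into the scalars above:
- //   m_hat = m/(1 - beta1^iter), v_hat = v/(1 - beta2^iter)
- //   w <- w*(1 - alpha*wd) - alpha * m_hat / (sqrt(v_hat) + eps)
- // i.e. mh == alpha*m_hat and vh == sqrt(v_hat) + eps below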
-
- for (int ir = ir0; ir < ir1; ++ir) {
- const int64_t i03 = ir/(ne02*ne01);
- const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
- const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
-
- float * w = (float *) ((char *) src0->data + offset); // weight
- const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
- float * m = (float *) ((char *) src0_grad_m->data + offset);
- float * v = (float *) ((char *) src0_grad_v->data + offset);
-
- for (int i00 = 0; i00 < ne00; ++i00) {
- m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
- v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
-
- const float mh = m[i00]*beta1h;
- const float vh = sqrtf(v[i00]*beta2h) + eps;
-
- // The weight decay is applied independently of the Adam momenta m and v.
- // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
- // See: https://arxiv.org/pdf/1711.05101v3.pdf
- w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
- }
- }
-
- ggml_barrier(params->threadpool);
- if (ith != 0) {
- return;
- }
-
- iter++;
- memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
-}
-
-static void ggml_compute_forward_opt_step_adamw(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * src0 = dst->src[0];
-
- switch (src0->type) {
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_opt_step_adamw_f32(params, dst);
- } break;
- default:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-/////////////////////////////////
-
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
- GGML_ASSERT(params);
-
- if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
- return;
- }
-
- switch (tensor->op) {
- case GGML_OP_DUP:
- {
- ggml_compute_forward_dup(params, tensor);
- } break;
- case GGML_OP_ADD:
- {
- ggml_compute_forward_add(params, tensor);
- } break;
- case GGML_OP_ADD1:
- {
- ggml_compute_forward_add1(params, tensor);
- } break;
- case GGML_OP_ACC:
- {
- ggml_compute_forward_acc(params, tensor);
- } break;
- case GGML_OP_SUB:
- {
- ggml_compute_forward_sub(params, tensor);
- } break;
- case GGML_OP_MUL:
- {
- ggml_compute_forward_mul(params, tensor);
- } break;
- case GGML_OP_DIV:
- {
- ggml_compute_forward_div(params, tensor);
- } break;
- case GGML_OP_SQR:
- {
- ggml_compute_forward_sqr(params, tensor);
- } break;
- case GGML_OP_SQRT:
- {
- ggml_compute_forward_sqrt(params, tensor);
- } break;
- case GGML_OP_LOG:
- {
- ggml_compute_forward_log(params, tensor);
- } break;
- case GGML_OP_SIN:
- {
- ggml_compute_forward_sin(params, tensor);
- } break;
- case GGML_OP_COS:
- {
- ggml_compute_forward_cos(params, tensor);
- } break;
- case GGML_OP_SUM:
- {
- ggml_compute_forward_sum(params, tensor);
- } break;
- case GGML_OP_SUM_ROWS:
- {
- ggml_compute_forward_sum_rows(params, tensor);
- } break;
- case GGML_OP_MEAN:
- {
- ggml_compute_forward_mean(params, tensor);
- } break;
- case GGML_OP_ARGMAX:
- {
- ggml_compute_forward_argmax(params, tensor);
- } break;
- case GGML_OP_COUNT_EQUAL:
- {
- ggml_compute_forward_count_equal(params, tensor);
- } break;
- case GGML_OP_REPEAT:
- {
- ggml_compute_forward_repeat(params, tensor);
- } break;
- case GGML_OP_REPEAT_BACK:
- {
- ggml_compute_forward_repeat_back(params, tensor);
- } break;
- case GGML_OP_CONCAT:
- {
- ggml_compute_forward_concat(params, tensor);
- } break;
- case GGML_OP_SILU_BACK:
- {
- ggml_compute_forward_silu_back(params, tensor);
- } break;
- case GGML_OP_NORM:
- {
- ggml_compute_forward_norm(params, tensor);
- } break;
- case GGML_OP_RMS_NORM:
- {
- ggml_compute_forward_rms_norm(params, tensor);
- } break;
- case GGML_OP_RMS_NORM_BACK:
- {
- ggml_compute_forward_rms_norm_back(params, tensor);
- } break;
- case GGML_OP_GROUP_NORM:
- {
- ggml_compute_forward_group_norm(params, tensor);
- } break;
- case GGML_OP_MUL_MAT:
- {
- ggml_compute_forward_mul_mat(params, tensor);
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- ggml_compute_forward_mul_mat_id(params, tensor);
- } break;
- case GGML_OP_OUT_PROD:
- {
- ggml_compute_forward_out_prod(params, tensor);
- } break;
- case GGML_OP_SCALE:
- {
- ggml_compute_forward_scale(params, tensor);
- } break;
- case GGML_OP_SET:
- {
- ggml_compute_forward_set(params, tensor);
- } break;
- case GGML_OP_CPY:
- {
- ggml_compute_forward_cpy(params, tensor);
- } break;
- case GGML_OP_CONT:
- {
- ggml_compute_forward_cont(params, tensor);
- } break;
- case GGML_OP_RESHAPE:
- {
- ggml_compute_forward_reshape(params, tensor);
- } break;
- case GGML_OP_VIEW:
- {
- ggml_compute_forward_view(params, tensor);
- } break;
- case GGML_OP_PERMUTE:
- {
- ggml_compute_forward_permute(params, tensor);
- } break;
- case GGML_OP_TRANSPOSE:
- {
- ggml_compute_forward_transpose(params, tensor);
- } break;
- case GGML_OP_GET_ROWS:
- {
- ggml_compute_forward_get_rows(params, tensor);
- } break;
- case GGML_OP_GET_ROWS_BACK:
- {
- ggml_compute_forward_get_rows_back(params, tensor);
- } break;
- case GGML_OP_DIAG:
- {
- ggml_compute_forward_diag(params, tensor);
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- ggml_compute_forward_diag_mask_inf(params, tensor);
- } break;
- case GGML_OP_DIAG_MASK_ZERO:
- {
- ggml_compute_forward_diag_mask_zero(params, tensor);
- } break;
- case GGML_OP_SOFT_MAX:
- {
- ggml_compute_forward_soft_max(params, tensor);
- } break;
- case GGML_OP_SOFT_MAX_BACK:
- {
- ggml_compute_forward_soft_max_back(params, tensor);
- } break;
- case GGML_OP_ROPE:
- {
- ggml_compute_forward_rope(params, tensor);
- } break;
- case GGML_OP_ROPE_BACK:
- {
- ggml_compute_forward_rope_back(params, tensor);
- } break;
- case GGML_OP_CLAMP:
- {
- ggml_compute_forward_clamp(params, tensor);
- } break;
- case GGML_OP_CONV_TRANSPOSE_1D:
- {
- ggml_compute_forward_conv_transpose_1d(params, tensor);
- } break;
- case GGML_OP_IM2COL:
- {
- ggml_compute_forward_im2col(params, tensor);
- } break;
- case GGML_OP_IM2COL_BACK:
- {
- ggml_compute_forward_im2col_back_f32(params, tensor);
- } break;
- case GGML_OP_CONV_TRANSPOSE_2D:
- {
- ggml_compute_forward_conv_transpose_2d(params, tensor);
- } break;
- case GGML_OP_POOL_1D:
- {
- ggml_compute_forward_pool_1d(params, tensor);
- } break;
- case GGML_OP_POOL_2D:
- {
- ggml_compute_forward_pool_2d(params, tensor);
- } break;
- case GGML_OP_POOL_2D_BACK:
- {
- ggml_compute_forward_pool_2d_back(params, tensor);
- } break;
- case GGML_OP_UPSCALE:
- {
- ggml_compute_forward_upscale(params, tensor);
- } break;
- case GGML_OP_PAD:
- {
- ggml_compute_forward_pad(params, tensor);
- } break;
- case GGML_OP_ARANGE:
- {
- ggml_compute_forward_arange(params, tensor);
- } break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- {
- ggml_compute_forward_timestep_embedding(params, tensor);
- } break;
- case GGML_OP_ARGSORT:
- {
- ggml_compute_forward_argsort(params, tensor);
- } break;
- case GGML_OP_LEAKY_RELU:
- {
- ggml_compute_forward_leaky_relu(params, tensor);
- } break;
- case GGML_OP_FLASH_ATTN_EXT:
- {
- ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
- } break;
- case GGML_OP_FLASH_ATTN_BACK:
- {
- int32_t t = ggml_get_op_params_i32(tensor, 0);
- GGML_ASSERT(t == 0 || t == 1);
- bool masked = t != 0;
- ggml_compute_forward_flash_attn_back(params, masked, tensor);
- } break;
- case GGML_OP_SSM_CONV:
- {
- ggml_compute_forward_ssm_conv(params, tensor);
- } break;
- case GGML_OP_SSM_SCAN:
- {
- ggml_compute_forward_ssm_scan(params, tensor);
- } break;
- case GGML_OP_WIN_PART:
- {
- ggml_compute_forward_win_part(params, tensor);
- } break;
- case GGML_OP_WIN_UNPART:
- {
- ggml_compute_forward_win_unpart(params, tensor);
- } break;
- case GGML_OP_UNARY:
- {
- ggml_compute_forward_unary(params, tensor);
- } break;
- case GGML_OP_GET_REL_POS:
- {
- ggml_compute_forward_get_rel_pos(params, tensor);
- } break;
- case GGML_OP_ADD_REL_POS:
- {
- ggml_compute_forward_add_rel_pos(params, tensor);
- } break;
- case GGML_OP_RWKV_WKV6:
- {
- ggml_compute_forward_rwkv_wkv6(params, tensor);
- } break;
- case GGML_OP_MAP_UNARY:
- {
- ggml_unary_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- ggml_compute_forward_map_unary(params, tensor, fun);
- }
- break;
- case GGML_OP_MAP_BINARY:
- {
- ggml_binary_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- ggml_compute_forward_map_binary(params, tensor, fun);
- }
- break;
- case GGML_OP_MAP_CUSTOM1_F32:
- {
- ggml_custom1_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- ggml_compute_forward_map_custom1_f32(params, tensor, fun);
- }
- break;
- case GGML_OP_MAP_CUSTOM2_F32:
- {
- ggml_custom2_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- ggml_compute_forward_map_custom2_f32(params, tensor, fun);
- }
- break;
- case GGML_OP_MAP_CUSTOM3_F32:
- {
- ggml_custom3_op_f32_t fun;
- memcpy(&fun, tensor->op_params, sizeof(fun));
- ggml_compute_forward_map_custom3_f32(params, tensor, fun);
- }
- break;
- case GGML_OP_MAP_CUSTOM1:
- {
- ggml_compute_forward_map_custom1(params, tensor);
- }
- break;
- case GGML_OP_MAP_CUSTOM2:
- {
- ggml_compute_forward_map_custom2(params, tensor);
- }
- break;
- case GGML_OP_MAP_CUSTOM3:
- {
- ggml_compute_forward_map_custom3(params, tensor);
- }
- break;
- case GGML_OP_CROSS_ENTROPY_LOSS:
- {
- ggml_compute_forward_cross_entropy_loss(params, tensor);
- }
- break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- {
- ggml_compute_forward_cross_entropy_loss_back(params, tensor);
- }
- break;
- case GGML_OP_OPT_STEP_ADAMW:
- {
- ggml_compute_forward_opt_step_adamw(params, tensor);
- }
- break;
- case GGML_OP_NONE:
- {
- // nop
- } break;
- case GGML_OP_COUNT:
- {
- GGML_ABORT("fatal error");
- }
- }
-}
-
-// Android's libc implementation "bionic" does not support setting affinity
-#if defined(__gnu_linux__)
-static void set_numa_thread_affinity(int thread_n) {
- if (!ggml_is_numa()) {
- return;
- }
-
- int node_num;
- int rv;
- size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
-
- switch(g_state.numa.numa_strategy) {
- case GGML_NUMA_STRATEGY_DISTRIBUTE:
- // run thread on node_num thread_n / (threads per node)
- node_num = thread_n % g_state.numa.n_nodes;
- break;
- case GGML_NUMA_STRATEGY_ISOLATE:
- // run thread on current_node
- node_num = g_state.numa.current_node;
- break;
- case GGML_NUMA_STRATEGY_NUMACTL:
- // use the cpuset that numactl gave us
- rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
- if (rv) {
- fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
- }
- return;
- default:
- return;
- }
-
- struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
-
- cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
- CPU_ZERO_S(setsize, cpus);
- for (size_t i = 0; i < node->n_cpus; ++i) {
- CPU_SET_S(node->cpus[i], setsize, cpus);
- }
-
- rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
- if (rv) {
- fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
- }
-
- CPU_FREE(cpus);
-}
-
-static void clear_numa_thread_affinity(void) {
- if (!ggml_is_numa()) {
- return;
- }
-
- size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
-
- cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
- CPU_ZERO_S(setsize, cpus);
- for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
- CPU_SET_S(i, setsize, cpus);
- }
-
- int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
- if (rv) {
- fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
- }
-
- CPU_FREE(cpus);
-}
-#else
-// TODO: Windows etc.
-// (the linux implementation may also work on BSD, someone should test)
-static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
-static void clear_numa_thread_affinity(void) {}
-#endif
-
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
- int n_tasks = 0;
-
- if (ggml_is_empty(node)) {
- // no need to multi-thread a no-op
- n_tasks = 1;
- return n_tasks;
- }
-
- switch (node->op) {
- case GGML_OP_CPY:
- case GGML_OP_DUP:
- case GGML_OP_CONT:
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- case GGML_OP_ACC:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_SUB:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_LOG:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_MEAN:
- case GGML_OP_ARGMAX:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_COUNT_EQUAL:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_REPEAT:
- case GGML_OP_REPEAT_BACK:
- case GGML_OP_LEAKY_RELU:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(node)) {
- case GGML_UNARY_OP_ABS:
- case GGML_UNARY_OP_SGN:
- case GGML_UNARY_OP_NEG:
- case GGML_UNARY_OP_STEP:
- case GGML_UNARY_OP_TANH:
- case GGML_UNARY_OP_ELU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_SIGMOID:
- case GGML_UNARY_OP_HARDSWISH:
- case GGML_UNARY_OP_HARDSIGMOID:
- case GGML_UNARY_OP_EXP:
- {
- n_tasks = 1;
- } break;
-
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_SILU:
- {
- n_tasks = n_threads;
- } break;
- default:
- GGML_ABORT("fatal error");
- }
- break;
- case GGML_OP_SILU_BACK:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_RMS_NORM_BACK:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_CONCAT:
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- case GGML_OP_OUT_PROD:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_GET_ROWS:
- {
- // FIXME: get_rows can use additional threads, but the cost of launching additional threads
- // decreases performance with GPU offloading
- //n_tasks = n_threads;
- n_tasks = 1;
- } break;
- case GGML_OP_SCALE:
- case GGML_OP_SET:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_GET_ROWS_BACK:
- case GGML_OP_DIAG:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_DIAG_MASK_ZERO:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
- case GGML_OP_ADD_REL_POS:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_CLAMP:
- {
- n_tasks = 1; //TODO
- } break;
- case GGML_OP_SOFT_MAX:
- {
- n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
- } break;
- case GGML_OP_IM2COL:
- case GGML_OP_IM2COL_BACK:
- case GGML_OP_CONV_TRANSPOSE_1D:
- case GGML_OP_CONV_TRANSPOSE_2D:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_POOL_1D:
- case GGML_OP_POOL_2D:
- case GGML_OP_POOL_2D_BACK:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
- case GGML_OP_ARANGE:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_ARGSORT:
- case GGML_OP_FLASH_ATTN_EXT:
- case GGML_OP_FLASH_ATTN_BACK:
- case GGML_OP_SSM_CONV:
- case GGML_OP_SSM_SCAN:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_WIN_PART:
- case GGML_OP_WIN_UNPART:
- case GGML_OP_GET_REL_POS:
- case GGML_OP_RWKV_WKV6:
- case GGML_OP_MAP_UNARY:
- case GGML_OP_MAP_BINARY:
- case GGML_OP_MAP_CUSTOM1_F32:
- case GGML_OP_MAP_CUSTOM2_F32:
- case GGML_OP_MAP_CUSTOM3_F32:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_MAP_CUSTOM1:
- {
- struct ggml_map_custom1_op_params p;
- memcpy(&p, node->op_params, sizeof(p));
- if (p.n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p.n_tasks, n_threads);
- }
- } break;
- case GGML_OP_MAP_CUSTOM2:
- {
- struct ggml_map_custom2_op_params p;
- memcpy(&p, node->op_params, sizeof(p));
- if (p.n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p.n_tasks, n_threads);
- }
- } break;
- case GGML_OP_MAP_CUSTOM3:
- {
- struct ggml_map_custom3_op_params p;
- memcpy(&p, node->op_params, sizeof(p));
- if (p.n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p.n_tasks, n_threads);
- }
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS:
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- case GGML_OP_OPT_STEP_ADAMW:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_NONE:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_COUNT:
- {
- GGML_ABORT("fatal error");
- }
- default:
- {
- fprintf(stderr, "%s: op not implemented: ", __func__);
- if (node->op < GGML_OP_COUNT) {
- fprintf(stderr, "%s\n", ggml_op_name(node->op));
- } else {
- fprintf(stderr, "%d\n", node->op);
- }
- GGML_ABORT("fatal error");
- }
- }
-
- assert(n_tasks > 0);
-
- return n_tasks;
-}
-
-static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
-
-#if defined(_WIN32)
-#include "windows.h"
-
-// TODO: support > 64 CPUs
-bool ggml_thread_apply_affinity(bool * mask) {
- HANDLE h = GetCurrentThread();
- uint64_t bitmask = 0ULL;
-
- assert(GGML_MAX_N_THREADS >= 64);
-
- for (int32_t i = 0; i < 8; i++) {
- int32_t idx = i * 8;
- uint8_t val = 0;
- val |= mask[idx + 0] << 0;
- val |= mask[idx + 1] << 1;
- val |= mask[idx + 2] << 2;
- val |= mask[idx + 3] << 3;
- val |= mask[idx + 4] << 4;
- val |= mask[idx + 5] << 5;
- val |= mask[idx + 6] << 6;
- val |= mask[idx + 7] << 7;
- bitmask |= (uint64_t)val << idx;
- }
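-    // Illustrative example: if only mask[0] and mask[2] are set, the packed bitmask
-    // above is 0b101 = 0x5, i.e. the thread is allowed to run on CPUs 0 and 2.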
-
- for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {
- if (mask[i]) {
- fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
- break;
- }
- }
-
- DWORD_PTR m = (DWORD_PTR)bitmask;
-
- m = SetThreadAffinityMask(h, m);
-
- return m != 0;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
-    // Note that on Windows the Process Priority Class must be updated for thread priority changes to take effect.
-    // This is left to the application.
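-    // For example (illustrative, not done here), an application could call
-    //   SetPriorityClass(GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS);
-    // before raising individual thread priorities.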
- DWORD p = THREAD_PRIORITY_NORMAL;
- switch (prio) {
- case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
- case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
- case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
- case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
- }
-
- if (prio == GGML_SCHED_PRIO_NORMAL) {
- // Keep inherited policy/priority
- return true;
- }
-
- if (!SetThreadPriority(GetCurrentThread(), p)) {
- fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
- return false;
- }
-
- return true;
-}
-
-#elif defined(__APPLE__)
-#include <sys/types.h>
-#include <sys/resource.h>
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
- // Not supported on Apple platforms
- UNUSED(mask);
- return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
- struct sched_param p;
- int32_t policy = SCHED_OTHER;
- switch (prio) {
- case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
- case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
- case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
- case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
- }
-
- if (prio == GGML_SCHED_PRIO_NORMAL) {
- // Keep inherited policy/priority
- return true;
- }
-
- int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
- if (err != 0) {
- fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
- return false;
- }
-
- return true;
-}
-
-#elif defined(__gnu_linux__)
-// TODO: this may not work on BSD, to be verified
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
- cpu_set_t cpuset;
- int err;
-
- CPU_ZERO(&cpuset);
-
- for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
- if (mask[i]) {
- GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
- CPU_SET(i, &cpuset);
- }
- }
-
-#ifdef __ANDROID__
- err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
- if (err < 0) {
- err = errno;
- }
-#else
- err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
-#endif
- if (err != 0) {
- fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
- return false;
- }
-
- return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
- struct sched_param p;
- int32_t policy = SCHED_OTHER;
- switch (prio) {
- case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
- case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
- case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
- case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
- }
-
- if (prio == GGML_SCHED_PRIO_NORMAL) {
- // Keep inherited policy/priority
- return true;
- }
-
- int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
- if (err != 0) {
- fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
- return false;
- }
-
- return true;
-}
-
-#else // unsupported platforms
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
- UNUSED(mask);
- return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
- UNUSED(prio);
- return true;
-}
-
-#endif
-
-static bool ggml_thread_cpumask_is_valid(const bool * mask) {
- for (int i = 0; i < GGML_MAX_N_THREADS; i++) {
- if (mask[i]) { return true; }
- }
- return false;
-}
-
-static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
- if (!strict) {
- memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);
- return;
- } else {
- memset(local_mask, 0, GGML_MAX_N_THREADS);
- int32_t base_idx = *iter;
- for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
- int32_t idx = base_idx + i;
- if (idx >= GGML_MAX_N_THREADS) {
- // Just a cheaper modulo
- idx -= GGML_MAX_N_THREADS;
- }
- if (global_mask[idx]) {
- local_mask[idx] = 1;
- *iter = idx + 1;
- return;
- }
- }
- }
-}
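-// Illustrative example of the strict rotation above: with CPUs {2, 5, 7} set in the global
-// mask and *iter starting at 0, successive calls assign CPU 2, then 5, then 7, then wrap
-// back to 2, handing each successive worker one CPU from the mask in round-robin order.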
-
-void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
- if (!threadpool) return;
-
- const int n_threads = threadpool->n_threads_max;
-
-#ifndef GGML_USE_OPENMP
- struct ggml_compute_state* workers = threadpool->workers;
-
- ggml_mutex_lock(&threadpool->mutex);
-
- threadpool->stop = true;
- threadpool->pause = false;
-
- ggml_cond_broadcast(&threadpool->cond);
- ggml_mutex_unlock(&threadpool->mutex);
-
- for (int j = 1; j < n_threads; j++) {
- int32_t rc = ggml_thread_join(workers[j].thrd, NULL);
- GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED);
- UNUSED(rc);
- }
-
- ggml_mutex_destroy(&threadpool->mutex);
- ggml_cond_destroy(&threadpool->cond);
-#endif // GGML_USE_OPENMP
-
- const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
- ggml_aligned_free(threadpool->workers, workers_size);
- ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
-}
-
-#ifndef GGML_USE_OPENMP
-// pause/resume must be called under mutex
-static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {
- GGML_PRINT_DEBUG("Pausing threadpool\n");
- threadpool->pause = true;
- ggml_cond_broadcast(&threadpool->cond);
-}
-
-static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {
- GGML_PRINT_DEBUG("Resuming threadpool\n");
- threadpool->pause = false;
- ggml_cond_broadcast(&threadpool->cond);
-}
-#endif
-
-void ggml_threadpool_pause(struct ggml_threadpool * threadpool) {
-#ifndef GGML_USE_OPENMP
- ggml_mutex_lock(&threadpool->mutex);
- if (!threadpool->pause) {
- ggml_threadpool_pause_locked(threadpool);
- }
- ggml_mutex_unlock(&threadpool->mutex);
-#else
- UNUSED(threadpool);
-#endif
-}
-
-void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
-#ifndef GGML_USE_OPENMP
- ggml_mutex_lock(&threadpool->mutex);
- if (threadpool->pause) {
- ggml_threadpool_resume_locked(threadpool);
- }
- ggml_mutex_unlock(&threadpool->mutex);
-#else
- UNUSED(threadpool);
-#endif
-}
-
-struct ggml_cplan ggml_graph_plan(
- const struct ggml_cgraph * cgraph,
- int n_threads,
- struct ggml_threadpool * threadpool) {
-
- if (threadpool == NULL) {
- //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
- }
- if (n_threads <= 0) {
- n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
- }
-
- size_t work_size = 0;
-
- struct ggml_cplan cplan;
- memset(&cplan, 0, sizeof(struct ggml_cplan));
-
- int max_tasks = 1;
-
- // thread scheduling for the different operations + work buffer size estimation
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
-
- const int n_tasks = ggml_get_n_tasks(node, n_threads);
-
- max_tasks = MAX(max_tasks, n_tasks);
-
- size_t cur = 0;
-
- switch (node->op) {
- case GGML_OP_CPY:
- case GGML_OP_DUP:
- {
- if (ggml_is_quantized(node->type) ||
- // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
- (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
- (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
- cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
- }
- } break;
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- {
- if (ggml_is_quantized(node->src[0]->type)) {
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
- }
- } break;
- case GGML_OP_ACC:
- {
- if (ggml_is_quantized(node->src[0]->type)) {
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
- }
- } break;
- case GGML_OP_COUNT_EQUAL:
- {
- cur = ggml_type_size(node->type)*n_tasks;
- } break;
- case GGML_OP_MUL_MAT:
- {
- const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
-
- if (node->src[1]->type != vec_dot_type) {
- cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
- }
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- cur = 0;
- const struct ggml_tensor * src0 = node->src[0];
- const struct ggml_tensor * src1 = node->src[1];
- const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
- if (src1->type != vec_dot_type) {
- cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
- }
- const int n_as = src0->ne[2];
- cur += GGML_PAD(cur, sizeof(int64_t)); // align
- cur += n_as * sizeof(int64_t); // matrix_row_counts
- cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
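-                    // Resulting work buffer layout, in order: converted src1 rows (only when a
-                    // conversion to vec_dot_type is needed), alignment padding to an int64_t
-                    // boundary, n_as row counters (matrix_row_counts) and n_as*ne12 row ids (matrix_rows).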
- } break;
- case GGML_OP_OUT_PROD:
- {
- if (ggml_is_quantized(node->src[0]->type)) {
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
- }
- } break;
- case GGML_OP_SOFT_MAX:
- case GGML_OP_ROPE:
- {
- cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
- } break;
- case GGML_OP_CONV_TRANSPOSE_1D:
- {
- GGML_ASSERT(node->src[0]->ne[3] == 1);
- GGML_ASSERT(node->src[1]->ne[2] == 1);
- GGML_ASSERT(node->src[1]->ne[3] == 1);
-
- const int64_t ne00 = node->src[0]->ne[0]; // K
- const int64_t ne01 = node->src[0]->ne[1]; // Cout
- const int64_t ne02 = node->src[0]->ne[2]; // Cin
-
- const int64_t ne10 = node->src[1]->ne[0]; // L
- const int64_t ne11 = node->src[1]->ne[1]; // Cin
-
- if ((node->src[0]->type == GGML_TYPE_F16 ||
- node->src[0]->type == GGML_TYPE_BF16) &&
- node->src[1]->type == GGML_TYPE_F32) {
- cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
- cur += sizeof(ggml_fp16_t)*ne10*ne11;
- } else if (node->src[0]->type == GGML_TYPE_F32 &&
- node->src[1]->type == GGML_TYPE_F32) {
- cur += sizeof(float)*ne00*ne01*ne02;
- cur += sizeof(float)*ne10*ne11;
- } else {
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_CONV_TRANSPOSE_2D:
- {
- const int64_t ne00 = node->src[0]->ne[0]; // W
- const int64_t ne01 = node->src[0]->ne[1]; // H
- const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
- const int64_t ne03 = node->src[0]->ne[3]; // Channels In
-
- const int64_t ne10 = node->src[1]->ne[0]; // W
- const int64_t ne11 = node->src[1]->ne[1]; // H
- const int64_t ne12 = node->src[1]->ne[2]; // Channels In
-
- cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
- cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
- } break;
- case GGML_OP_FLASH_ATTN_EXT:
- {
- const int64_t ne00 = node->src[0]->ne[0]; // D
-
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
- } break;
- case GGML_OP_FLASH_ATTN_BACK:
- {
- const int64_t D = node->src[0]->ne[0];
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
- if (node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_F16) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- }
- } break;
-
- case GGML_OP_CROSS_ENTROPY_LOSS:
- {
- cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
- } break;
- case GGML_OP_COUNT:
- {
- GGML_ABORT("fatal error");
- }
- default:
- break;
- }
-
- work_size = MAX(work_size, cur);
- }
-
- if (work_size > 0) {
- work_size += CACHE_LINE_SIZE*(n_threads);
- }
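-    // The extra CACHE_LINE_SIZE bytes per thread leave room to keep each thread's slice of
-    // the work buffer on separate cache lines (avoiding false sharing between threads).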
-
- cplan.threadpool = threadpool;
- cplan.n_threads = MIN(max_tasks, n_threads);
- cplan.work_size = work_size;
- cplan.work_data = NULL;
-
- return cplan;
-}
-
-static thread_ret_t ggml_graph_compute_thread(void * data) {
- struct ggml_compute_state * state = (struct ggml_compute_state *) data;
- struct ggml_threadpool * tp = state->threadpool;
-
- const struct ggml_cgraph * cgraph = tp->cgraph;
- const struct ggml_cplan * cplan = tp->cplan;
-
- set_numa_thread_affinity(state->ith);
-
- struct ggml_compute_params params = {
- /*.ith =*/ state->ith,
- /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
- /*.wsize =*/ cplan->work_size,
- /*.wdata =*/ cplan->work_data,
- /*.threadpool=*/ tp,
- };
-
- for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
- struct ggml_tensor * node = cgraph->nodes[node_n];
-
-        ggml_compute_forward(&params, node);
-
- if (state->ith == 0 && cplan->abort_callback &&
- cplan->abort_callback(cplan->abort_callback_data)) {
- tp->abort = true;
- tp->ec = GGML_STATUS_ABORTED;
- }
-
- ggml_barrier(state->threadpool);
- }
-
- return 0;
-}
-
-#ifndef GGML_USE_OPENMP
-
-// check if thread is active
-static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
- struct ggml_threadpool * threadpool = state->threadpool;
- int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
- return (state->ith < n_threads);
-}
-
-// check if thread is ready to proceed (exit from polling or sleeping)
-static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
- struct ggml_threadpool * threadpool = state->threadpool;
-
- if (state->pending || threadpool->stop || threadpool->pause) { return true; }
-
- // check for new graph/work
- int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
- if (new_graph != state->last_graph) {
- state->pending = ggml_graph_compute_thread_active(state);
- state->last_graph = new_graph;
- }
-
- return state->pending;
-}
-
-// sync thread state after polling
-static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
- // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
- #ifdef GGML_TSAN_ENABLED
- atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
- #else
- atomic_thread_fence(memory_order_seq_cst);
- #endif
- UNUSED(state);
-}
-
-static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
- struct ggml_threadpool * threadpool = state->threadpool;
-
- // Skip polling for unused threads
- if (!ggml_graph_compute_thread_active(state)) {
- return state->pending;
- }
-
-    // This makes 0..100 a reasonable range for the polling level across modern processors.
-    // It could potentially be adjusted dynamically based on the load.
- const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
-
- for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
- // No new work. Keep polling.
- ggml_thread_cpu_relax();
- }
-
- return state->pending;
-}
-
-static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
- struct ggml_threadpool * threadpool = state->threadpool;
-
- if (ggml_graph_compute_poll_for_work(state)) {
- ggml_graph_compute_thread_sync(state);
- return state->pending;
- }
-
- ggml_mutex_lock_shared(&threadpool->mutex);
- while (!ggml_graph_compute_thread_ready(state)) {
- // No new work. Wait for the signal.
- GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
- ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
- }
- ggml_mutex_unlock_shared(&threadpool->mutex);
-
- return state->pending;
-}
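-// Taken together: each worker first busy-polls for up to n_rounds iterations (scaled by
-// threadpool->poll) and only falls back to sleeping on the condition variable if no work
-// shows up, trading a little CPU time for lower wake-up latency between graphs.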
-
-static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
- struct ggml_compute_state * state = (struct ggml_compute_state *) data;
- struct ggml_threadpool * threadpool = state->threadpool;
-
- ggml_thread_apply_priority(threadpool->prio);
- if (ggml_thread_cpumask_is_valid(state->cpumask)) {
- ggml_thread_apply_affinity(state->cpumask);
- }
-
- while (true) {
- // Check if we need to sleep
- while (threadpool->pause) {
- GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
- ggml_mutex_lock_shared(&threadpool->mutex);
- if (threadpool->pause) {
- ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
- }
- GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
- ggml_mutex_unlock_shared(&threadpool->mutex);
- }
-
- // This needs to be checked for after the cond_wait
- if (threadpool->stop) break;
-
- // Check if there is new work
- // The main thread is the only one that can dispatch new work
-
- ggml_graph_compute_check_for_work(state);
- if (state->pending) {
- state->pending = false;
-
- ggml_graph_compute_thread(state);
- }
- }
-
- return (thread_ret_t) 0;
-}
-
-// Start processing new graph
-static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
-{
- // Always take the mutex here because the worker threads are doing hybrid poll/wait
-
- ggml_mutex_lock(&threadpool->mutex);
-
- GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
-
- // Update the number of active threads
- atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
-
- // Indicate the graph is ready to be processed
- // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
- atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
-
- if (threadpool->pause) {
- // Update main thread prio and affinity to match the threadpool settings
- ggml_thread_apply_priority(threadpool->prio);
- if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
- ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
- }
-
- // resume does cond broadcast
- ggml_threadpool_resume_locked(threadpool);
- } else {
- ggml_cond_broadcast(&threadpool->cond);
- }
-
- ggml_mutex_unlock(&threadpool->mutex);
-}
-
-#endif // GGML_USE_OPENMP
-
-void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
- p->n_threads = n_threads;
- p->prio = 0; // default priority (usually means normal or inherited)
- p->poll = 50; // hybrid-polling enabled
- p->strict_cpu = false; // no strict placement (all threads share same cpumask)
- p->paused = false; // threads are ready to go
- memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
-}
-
-struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
- struct ggml_threadpool_params p;
- ggml_threadpool_params_init(&p, n_threads);
- return p;
-}
-
-bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
- if (p0->n_threads != p1->n_threads ) return false;
- if (p0->prio != p1->prio ) return false;
- if (p0->poll != p1->poll ) return false;
- if (p0->strict_cpu != p1->strict_cpu ) return false;
- return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
-}
-
-static struct ggml_threadpool * ggml_threadpool_new_impl(
- struct ggml_threadpool_params * tpp,
- struct ggml_cgraph * cgraph,
- struct ggml_cplan * cplan) {
-
- struct ggml_threadpool * threadpool =
- ggml_aligned_malloc(sizeof(struct ggml_threadpool));
- {
- threadpool->cgraph = cgraph;
- threadpool->cplan = cplan;
- threadpool->n_graph = 0;
- threadpool->n_barrier = 0;
- threadpool->n_barrier_passed = 0;
- threadpool->current_chunk = 0;
- threadpool->stop = false;
- threadpool->pause = tpp->paused;
- threadpool->abort = false;
- threadpool->workers = NULL;
- threadpool->n_threads_max = tpp->n_threads;
- threadpool->n_threads_cur = tpp->n_threads;
- threadpool->poll = tpp->poll;
- threadpool->prio = tpp->prio;
- threadpool->ec = GGML_STATUS_SUCCESS;
- }
-
- // Allocate and init workers state
- const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
- struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
-
- memset(workers, 0, workers_size);
- for (int j = 0; j < tpp->n_threads; j++) {
- workers[j].threadpool = threadpool;
- workers[j].ith = j;
- }
-
- threadpool->workers = workers;
-
-#ifndef GGML_USE_OPENMP
- ggml_mutex_init(&threadpool->mutex);
- ggml_cond_init(&threadpool->cond);
-
-    // Spin up the threads for all workers, and update CPU placements.
-    // Place the main thread last (towards the higher numbered CPU cores).
-
- int32_t cpumask_iter = 0;
-
- for (int j = 1; j < tpp->n_threads; j++) {
- ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
-
- int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);
- GGML_ASSERT(rc == 0);
- }
-
- ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
-
- if (!threadpool->pause) {
- // Update main thread prio and affinity at the start, otherwise we'll do it in resume
- ggml_thread_apply_priority(threadpool->prio);
- if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
- ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
- }
- }
-#endif // GGML_USE_OPENMP
-
- return threadpool;
-}
-
-struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {
- return ggml_threadpool_new_impl(tpp, NULL, NULL);
-}
-
-enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
- ggml_cpu_init();
-
- GGML_ASSERT(cplan);
- GGML_ASSERT(cplan->n_threads > 0);
- GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
-
- int n_threads = cplan->n_threads;
- struct ggml_threadpool * threadpool = cplan->threadpool;
-
- bool disposable_threadpool = false;
-
- if (threadpool == NULL) {
- //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
- disposable_threadpool = true;
-
- struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);
- threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);
- } else {
- // Reset some of the parameters that need resetting
- // No worker threads should be accessing the parameters below at this stage
- threadpool->cgraph = cgraph;
- threadpool->cplan = cplan;
- threadpool->current_chunk = 0;
- threadpool->abort = false;
- threadpool->ec = GGML_STATUS_SUCCESS;
- }
-
-#ifdef GGML_USE_OPENMP
- if (n_threads > 1) {
- #pragma omp parallel num_threads(n_threads)
- {
- #pragma omp single
- {
- // update the number of threads from the actual number of threads that we got from OpenMP
- n_threads = omp_get_num_threads();
- atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
- }
-
- ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
- }
- } else {
- atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
- ggml_graph_compute_thread(&threadpool->workers[0]);
- }
-#else
- if (n_threads > threadpool->n_threads_max) {
- GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
- n_threads = threadpool->n_threads_max;
- }
-
- // Kick all threads to start the new graph
- ggml_graph_compute_kickoff(threadpool, n_threads);
-
- // This is a work thread too
- ggml_graph_compute_thread(&threadpool->workers[0]);
-#endif
-
- // don't leave affinity set on the main thread
- clear_numa_thread_affinity();
-
- enum ggml_status ret = threadpool->ec;
-
- if (disposable_threadpool) {
- ggml_threadpool_free(threadpool);
- }
-
- return ret;
-}
-
-enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
- struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);
-
- cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);
-
- return ggml_graph_compute(cgraph, &cplan);
-}
-
-int ggml_cpu_has_neon(void) {
-#if defined(__ARM_ARCH)
- return ggml_arm_arch_features.has_neon;
-#else
- return 0;
-#endif
-}
-
-int ggml_cpu_has_sve(void) {
-#if defined(__ARM_ARCH)
- return ggml_arm_arch_features.has_sve;
-#else
- return 0;
-#endif
-}
-
-int ggml_cpu_has_matmul_int8(void) {
-#if defined(__ARM_ARCH)
- return ggml_arm_arch_features.has_i8mm;
-#else
- return 0;
-#endif
-}
-
-int ggml_cpu_get_sve_cnt(void) {
-#if defined(__ARM_ARCH)
- return ggml_arm_arch_features.sve_cnt;
-#else
- return 0;
-#endif
-}
-
-void ggml_cpu_init(void) {
- // needed to initialize f16 tables
- {
- struct ggml_init_params params = { 0, NULL, false };
- struct ggml_context * ctx = ggml_init(params);
- ggml_free(ctx);
- }
-
- ggml_critical_section_start();
-
- static bool is_first_call = true;
-
- if (is_first_call) {
- // initialize GELU, Quick GELU, SILU and EXP F32 tables
- {
- const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
-
- for (int i = 0; i < (1 << 16); ++i) {
- union {
- uint16_t u16;
- ggml_fp16_t fp16;
- } u = {i};
- float f = GGML_FP16_TO_FP32(u.fp16);
- ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
- ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
- }
-
- const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
-
- GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
- }
-
-#if defined(__ARM_ARCH)
- ggml_init_arm_arch_features();
-#endif
-
- is_first_call = false;
- }
-
- ggml_critical_section_end();
-}
+++ /dev/null
-#include "ggml-cuda.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-cuda/common.cuh"
-#include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/arange.cuh"
-#include "ggml-cuda/argmax.cuh"
-#include "ggml-cuda/argsort.cuh"
-#include "ggml-cuda/binbcast.cuh"
-#include "ggml-cuda/clamp.cuh"
-#include "ggml-cuda/concat.cuh"
-#include "ggml-cuda/conv-transpose-1d.cuh"
-#include "ggml-cuda/convert.cuh"
-#include "ggml-cuda/count-equal.cuh"
-#include "ggml-cuda/cpy.cuh"
-#include "ggml-cuda/cross-entropy-loss.cuh"
-#include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
-#include "ggml-cuda/fattn.cuh"
-#include "ggml-cuda/getrows.cuh"
-#include "ggml-cuda/im2col.cuh"
-#include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmvq.cuh"
-#include "ggml-cuda/norm.cuh"
-#include "ggml-cuda/opt-step-adamw.cuh"
-#include "ggml-cuda/out-prod.cuh"
-#include "ggml-cuda/pad.cuh"
-#include "ggml-cuda/pool2d.cuh"
-#include "ggml-cuda/quantize.cuh"
-#include "ggml-cuda/rope.cuh"
-#include "ggml-cuda/scale.cuh"
-#include "ggml-cuda/softmax.cuh"
-#include "ggml-cuda/sum.cuh"
-#include "ggml-cuda/sumrows.cuh"
-#include "ggml-cuda/tsembd.cuh"
-#include "ggml-cuda/unary.cuh"
-#include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
-
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <cinttypes>
-#include <cstddef>
-#include <cstdint>
-#include <float.h>
-#include <limits>
-#include <map>
-#include <memory>
-#include <mutex>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdarg.h>
-#include <stdlib.h>
-#include <string>
-#include <vector>
-
-static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
-
-[[noreturn]]
-void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
- int id = -1; // in case cudaGetDevice fails
- cudaGetDevice(&id);
-
- GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
- GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
- GGML_LOG_ERROR(" %s\n", stmt);
- // abort with GGML_ABORT to get a stack trace
- GGML_ABORT(GGML_CUDA_NAME " error");
-}
-
-// this is faster on Windows
-// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
-void ggml_cuda_set_device(int device) {
- int current_device;
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
- if (device == current_device) {
- return;
- }
-
- CUDA_CHECK(cudaSetDevice(device));
-}
-
-int ggml_cuda_get_device() {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- return id;
-}
-
-static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
- ggml_cuda_set_device(device);
-#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
- auto res = hipMallocManaged(ptr, size);
- if (res == hipSuccess) {
-        // if hipMemAdvise fails we want to know why, hence the CUDA_CHECK
- CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
- }
- return res;
-#else
-
-#if !defined(GGML_USE_HIPBLAS)
- cudaError_t err;
- if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
- {
- err = cudaMallocManaged(ptr, size);
- }
- else
- {
- err = cudaMalloc(ptr, size);
- }
- return err;
-#else
- return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS)
-
-#endif
-}
-
-static ggml_cuda_device_info ggml_cuda_init() {
-#ifdef __HIP_PLATFORM_AMD__
- // Workaround for a rocBLAS bug when using multiple graphics cards:
- // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
- rocblas_initialize();
- CUDA_CHECK(cudaDeviceSynchronize());
-#endif
-
- ggml_cuda_device_info info = {};
-
- cudaError_t err = cudaGetDeviceCount(&info.device_count);
- if (err != cudaSuccess) {
- GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
- return info;
- }
-
- GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
-
- int64_t total_vram = 0;
-#ifdef GGML_CUDA_FORCE_MMQ
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
-#else
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
-#endif // GGML_CUDA_FORCE_MMQ
-#ifdef GGML_CUDA_FORCE_CUBLAS
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
-#else
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
-#endif // GGML_CUDA_FORCE_CUBLAS
- GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
- for (int id = 0; id < info.device_count; ++id) {
- int device_vmm = 0;
-
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
- CUdevice device;
- CU_CHECK(cuDeviceGet(&device, id));
- CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
-
- if (device_vmm) {
- CUmemAllocationProp alloc_prop = {};
- alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
- alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
- alloc_prop.location.id = id;
- CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
- }
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
- info.devices[id].vmm = !!device_vmm;
-
- cudaDeviceProp prop;
- CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-
- info.default_tensor_split[id] = total_vram;
- total_vram += prop.totalGlobalMem;
-
- info.devices[id].nsm = prop.multiProcessorCount;
- info.devices[id].smpb = prop.sharedMemPerBlock;
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- info.devices[id].smpbo = prop.sharedMemPerBlock;
- info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
-#else
- info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
- info.devices[id].cc = 100*prop.major + 10*prop.minor;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- }
-
- for (int id = 0; id < info.device_count; ++id) {
- info.default_tensor_split[id] /= total_vram;
- }
-
- // configure logging to stdout
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
-
- return info;
-}
-
-const ggml_cuda_device_info & ggml_cuda_info() {
- static ggml_cuda_device_info info = ggml_cuda_init();
- return info;
-}
-
-// #define DEBUG_CUDA_MALLOC
-
-// buffer pool for cuda (legacy)
-struct ggml_cuda_pool_leg : public ggml_cuda_pool {
- static const int MAX_BUFFERS = 256;
-
- int device;
- struct ggml_cuda_buffer {
- void * ptr = nullptr;
- size_t size = 0;
- };
-
- ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
- size_t pool_size = 0;
-
- explicit ggml_cuda_pool_leg(int device) :
- device(device) {
- }
-
- ~ggml_cuda_pool_leg() {
- ggml_cuda_set_device(device);
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cuda_buffer & b = buffer_pool[i];
- if (b.ptr != nullptr) {
- CUDA_CHECK(cudaFree(b.ptr));
- pool_size -= b.size;
- }
- }
- GGML_ASSERT(pool_size == 0);
- }
-
- void * alloc(size_t size, size_t * actual_size) override {
-#ifdef DEBUG_CUDA_MALLOC
- int nnz = 0;
- size_t max_size = 0;
-#endif
- size_t best_diff = 1ull << 36;
- int ibest = -1;
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cuda_buffer& b = buffer_pool[i];
- if (b.ptr != nullptr) {
-#ifdef DEBUG_CUDA_MALLOC
- ++nnz;
- if (b.size > max_size) max_size = b.size;
-#endif
- if (b.size >= size) {
- size_t diff = b.size - size;
- if (diff < best_diff) {
- best_diff = diff;
- ibest = i;
- if (!best_diff) {
- void * ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- }
- }
- }
- }
- if (ibest >= 0) {
- ggml_cuda_buffer& b = buffer_pool[ibest];
- void * ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- void * ptr;
- size_t look_ahead_size = (size_t) (1.05 * size);
- look_ahead_size = 256 * ((look_ahead_size + 255)/256);
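-        // Illustrative example: a 1000-byte request grows to 1050 bytes with the 5% look-ahead
-        // and is then rounded up to 1280 bytes, the next multiple of 256.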
- ggml_cuda_set_device(device);
- CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
- *actual_size = look_ahead_size;
- pool_size += look_ahead_size;
-#ifdef DEBUG_CUDA_MALLOC
- GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
- (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
-#endif
- return ptr;
- }
-
- void free(void * ptr, size_t size) override {
- for (int i = 0; i < MAX_BUFFERS; ++i) {
- ggml_cuda_buffer& b = buffer_pool[i];
- if (b.ptr == nullptr) {
- b.ptr = ptr;
- b.size = size;
- return;
- }
- }
- GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
- ggml_cuda_set_device(device);
- CUDA_CHECK(cudaFree(ptr));
- pool_size -= size;
- }
-};
-
-// pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
- static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
-
- int device;
- CUdeviceptr pool_addr = 0;
- size_t pool_used = 0;
- size_t pool_size = 0;
- size_t granularity;
-
- explicit ggml_cuda_pool_vmm(int device) :
- device(device),
- granularity(ggml_cuda_info().devices[device].vmm_granularity) {
- }
-
- ~ggml_cuda_pool_vmm() {
- if (pool_addr != 0) {
- CU_CHECK(cuMemUnmap(pool_addr, pool_size));
- CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
- }
- }
-
- void * alloc(size_t size, size_t * actual_size) override {
- // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
- const size_t alignment = 128;
- size = alignment * ((size + alignment - 1) / alignment);
-
- size_t avail = pool_size - pool_used;
-
- if (size > avail) {
- // round up to the next multiple of the granularity
- size_t reserve_size = size - avail;
- reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
-
- GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
-
- // allocate more physical memory
- CUmemAllocationProp prop = {};
- prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
- prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
- prop.location.id = device;
- CUmemGenericAllocationHandle handle;
- CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
-
- // reserve virtual address space (if not already reserved)
- if (pool_addr == 0) {
- CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
- }
-
- // map at the end of the pool
- CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
-
- // the memory allocation handle is no longer needed after mapping
- CU_CHECK(cuMemRelease(handle));
-
- // set access
- CUmemAccessDesc access = {};
- access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
- access.location.id = device;
- access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
- CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
-
- // add to the pool
- pool_size += reserve_size;
-
- //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
- // device, (unsigned long long) (pool_size/1024/1024),
- // (unsigned long long) (reserve_size/1024/1024));
- }
-
- GGML_ASSERT(pool_addr != 0);
-
- void * ptr = (void *) (pool_addr + pool_used);
- *actual_size = size;
- pool_used += size;
-
-#ifdef DEBUG_CUDA_MALLOC
- printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
-#endif
-
- return ptr;
- }
-
- void free(void * ptr, size_t size) override {
-#ifdef DEBUG_CUDA_MALLOC
- printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
-#endif
-
- pool_used -= size;
-
- // all deallocations must be in reverse order of the allocations
- GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
- }
-};
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
- if (ggml_cuda_info().devices[device].vmm) {
- return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
- }
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
- return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
-}
-
-// cuda buffer
-
-struct ggml_backend_cuda_buffer_context {
- int device;
- void * dev_ptr = nullptr;
- std::string name;
-
- ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
- device(device), dev_ptr(dev_ptr),
- name(GGML_CUDA_NAME + std::to_string(device)) {
- }
-
- ~ggml_backend_cuda_buffer_context() {
- CUDA_CHECK(cudaFree(dev_ptr));
- }
-};
-
-static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
- delete ctx;
-}
-
-static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
- return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer;
-}
-
-static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
- return ctx->dev_ptr;
-}
-
-static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
- if (tensor->view_src != NULL) {
- assert(tensor->view_src->buffer->buft == buffer->buft);
- return;
- }
-
- if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
- // initialize padding to 0 to avoid possible NaN values
- size_t original_size = ggml_nbytes(tensor);
- size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
- if (padded_size > original_size) {
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
- }
- }
-}
-
-static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
- if (ggml_backend_buffer_is_cuda(src->buffer)) {
- ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
- ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
- if (src_ctx->device == dst_ctx->device) {
- CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
- } else {
-#ifdef GGML_CUDA_NO_PEER_COPY
- return false;
-#else
- CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
-#endif
- }
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
- return true;
- }
- return false;
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaDeviceSynchronize());
- CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
- CUDA_CHECK(cudaDeviceSynchronize());
-}
-
-static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
- /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
- /* .get_base = */ ggml_backend_cuda_buffer_get_base,
- /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
- /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_cuda_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// cuda buffer type
-struct ggml_backend_cuda_buffer_type_context {
- int device;
- std::string name;
-};
-
-static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-
- return ctx->name.c_str();
-}
-
-static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-
- ggml_cuda_set_device(buft_ctx->device);
-
- void * dev_ptr;
- cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
- if (err != cudaSuccess) {
- // clear the error
- cudaGetLastError();
- GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
- return nullptr;
- }
-
- ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
-
- return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return 128;
-
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- size_t size = ggml_nbytes(tensor);
- int64_t ne0 = tensor->ne[0];
-
- if (ggml_is_quantized(tensor->type)) {
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
- }
-
- return size;
-
- GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
- /* .get_name = */ ggml_backend_cuda_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
- /* .is_host = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- if (device >= ggml_backend_cuda_get_device_count()) {
- return nullptr;
- }
-
- static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
-
- static bool ggml_backend_cuda_buffer_type_initialized = false;
-
- if (!ggml_backend_cuda_buffer_type_initialized) {
- for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {
- ggml_backend_cuda_buffer_types[i] = {
- /* .iface = */ ggml_backend_cuda_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),
- /* .context = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
- };
- }
- ggml_backend_cuda_buffer_type_initialized = true;
- }
-
- return &ggml_backend_cuda_buffer_types[device];
-}
-
-// cuda split buffer
-
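-// row splits are rounded to the mul_mat_q tile height; use the largest tile height among the devices that receive rows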
-static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
- int64_t row_rounding = 0;
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
- continue;
- }
-
- const int cc = ggml_cuda_info().devices[id].cc;
- row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
- }
- return row_rounding;
-}
-
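-// compute the [row_low, row_high) slice of rows assigned to device id from the cumulative tensor_split fractions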
-static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
- const int64_t nrows = ggml_nrows(tensor);
- const int64_t rounding = get_row_rounding(tensor_split);
-
- *row_low = id == 0 ? 0 : nrows*tensor_split[id];
- *row_low -= *row_low % rounding;
-
- if (id == ggml_backend_cuda_get_device_count() - 1) {
- *row_high = nrows;
- } else {
- *row_high = nrows*tensor_split[id + 1];
- *row_high -= *row_high % rounding;
- }
-}
-
-static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
- static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
- return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
-}
-
-struct ggml_backend_cuda_split_buffer_type_context {
- int main_device;
- std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
- std::string name;
-};
-
-struct ggml_backend_cuda_split_buffer_context {
- ~ggml_backend_cuda_split_buffer_context() {
- for (ggml_tensor_extra_gpu * extra : tensor_extras) {
- for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
- for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
- if (extra->events[id][is] != nullptr) {
- CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
- }
- }
- if (extra->data_device[id] != nullptr) {
- CUDA_CHECK(cudaFree(extra->data_device[id]));
- }
- }
- delete extra;
- }
- }
-
- std::vector<ggml_tensor_extra_gpu *> tensor_extras;
-};
-
-
-static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
- delete ctx;
-}
-
-static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
- // the pointers are stored in the tensor extras; this is just a dummy address that is never dereferenced
- return (void *)0x1000;
-
- GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
- GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
-
- ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
- ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
-
- ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
- ctx->tensor_extras.push_back(extra);
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- // FIXME: do not crash if cudaMalloc fails
- // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
- ggml_cuda_set_device(id);
- char * buf;
- CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
-
- // set padding to 0 to avoid possible NaN values
- if (size > original_size) {
- CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
- }
-
- extra->data_device[id] = buf;
-
- for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
- CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
- }
- }
- tensor->extra = extra;
-}
-
-static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- // split tensors must always be set in their entirety at once
- GGML_ASSERT(offset == 0);
- GGML_ASSERT(size == ggml_nbytes(tensor));
-
- ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
- const size_t nb1 = tensor->nb[1];
- ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- const size_t offset_split = row_low*nb1;
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- const char * buf_host = (const char *)data + offset_split;
- CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
- }
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
- }
-}
-
-static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- // split tensors must always be read in their entirety at once
- GGML_ASSERT(offset == 0);
- GGML_ASSERT(size == ggml_nbytes(tensor));
-
- ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
- const size_t nb1 = tensor->nb[1];
- ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- const size_t offset_split = row_low*nb1;
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- char * buf_host = (char *)data + offset_split;
- CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
- }
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
- }
-}
-
-static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- GGML_UNUSED(buffer);
- GGML_UNUSED(value);
-}
-
-static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
- /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
- /* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
- /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
- /* .cpy_tensor = */ NULL,
- /* .clear = */ ggml_backend_cuda_split_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// cuda split buffer type
-
-static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
-
- return ctx->name.c_str();
-}
-
-static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
- // instead, we allocate them for each tensor separately in init_tensor
- // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
- // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
- ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
-
- return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return 128;
-
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
-
- size_t total_size = 0;
-
- const int64_t ne0 = tensor->ne[0];
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- total_size += ggml_nbytes_split(tensor, nrows_split);
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
- }
-
- return total_size;
-}
-
-static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return false;
-
- GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
- /* .get_name = */ ggml_backend_cuda_split_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
- /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
-};
-
-ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;
-
- std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
-
- bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
- if (all_zero) {
- tensor_split_arr = ggml_cuda_info().default_tensor_split;
- } else {
- float split_sum = 0.0f;
- for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
- tensor_split_arr[i] = split_sum;
- split_sum += tensor_split[i];
- }
- for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
- tensor_split_arr[i] /= split_sum;
- }
- }
-
- auto it = buft_map.find({main_device, tensor_split_arr});
- if (it != buft_map.end()) {
- return &it->second;
- }
- auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
- main_device,
- tensor_split_arr,
- GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
- };
-
- struct ggml_backend_buffer_type buft {
- /* .iface = */ ggml_backend_cuda_split_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
- /* .context = */ ctx,
- };
-
- auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
- return &result.first->second;
-}
-
-// host buffer type
-
-static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
- return GGML_CUDA_NAME "_Host";
-
- GGML_UNUSED(buft);
-}
-
-static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- CUDA_CHECK(cudaFreeHost(buffer->context));
-}
-
-static void * ggml_cuda_host_malloc(size_t size) {
- if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
- return nullptr;
- }
-
- void * ptr = nullptr;
- cudaError_t err = cudaMallocHost((void **) &ptr, size);
- if (err != cudaSuccess) {
- // clear the error
- cudaGetLastError();
- GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
- size / 1024.0 / 1024.0, cudaGetErrorString(err));
- return nullptr;
- }
-
- return ptr;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- void * ptr = ggml_cuda_host_malloc(size);
-
- if (ptr == nullptr) {
- // fallback to cpu buffer
- return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
- }
-
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
- buffer->buft = buft;
- buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
-
- return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
- static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_cuda_host_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_cuda_buffer_type_host;
-}
-
-//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
-// return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
-//}
-
-/// kernels
-
-typedef void (*ggml_cuda_op_mul_mat_t)(
- ggml_backend_cuda_context & ctx,
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
- const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
- const int64_t src1_padded_row_size, cudaStream_t stream);
-
-#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
-#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
-#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
-
-#define MUL_MAT_SRC1_COL_STRIDE 128
-
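-// f16 x f32 matrix-vector product for permuted src0/src1 (see ggml_cuda_mul_mat_vec_p021); one warp per (row, channel) reduces the dot product over the columns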
-static __global__ void mul_mat_p021_f16_f32(
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
- const half * x = (const half *) vx;
-
- const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
- const int channel = blockDim.z*blockIdx.z + threadIdx.z;
- const int channel_x = channel / (nchannels_y / nchannels_x);
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
- const int col_x = col_x0 + threadIdx.x;
-
- if (col_x >= ncols_x) {
- break;
- }
-
- // x is transposed and permuted
- const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
- const float xi = __half2float(x[ix]);
-
- const int row_y = col_x;
-
- // y is not transposed but permuted
- const int iy = channel*nrows_y + row_y;
-
- tmp += xi * y[iy];
- }
-
- // dst is not transposed and not permuted
- const int idst = channel*nrows_dst + row_dst;
-
- // sum up partial sums and write back result
- tmp = warp_reduce_sum(tmp);
-
- if (threadIdx.x == 0) {
- dst[idst] = tmp;
- }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
- const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
- const half * x = (const half *) vx;
-
- const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
- const int channel = blockDim.z*blockIdx.z + threadIdx.z;
- const int channel_x = channel / channel_x_divisor;
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- const int idst = channel*nrows_dst + row_dst;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
- const int col_x = col_x0 + threadIdx.x;
-
- if (col_x >= ncols_x) {
- break;
- }
-
- const int row_y = col_x;
-
- const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
- const int iy = channel*nrows_y + row_y;
-
- const float xi = __half2float(x[ix]);
-
- tmp += xi * y[iy];
- }
-
- // sum up partial sums and write back result
- tmp = warp_reduce_sum(tmp);
-
- if (threadIdx.x == 0) {
- dst[idst] = tmp;
- }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
- const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
- const dim3 block_nums(1, nrows_x, nchannels_y);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
- const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
- const dim3 block_nums(1, nrows_x, nchannels_y);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
- (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
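-// copy rows [i1_low, i1_high) of the (i2, i3) slice of src into a contiguous device buffer, using 2D copies for non-contiguous layouts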
-static cudaError_t ggml_cuda_cpy_tensor_2d(
- void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
-
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
- const char * src_ptr = (const char *) src->data;
- char * dst_ptr = (char *) dst;
-
- const int64_t ne0 = src->ne[0];
- const int64_t nb0 = src->nb[0];
- const int64_t nb1 = src->nb[1];
- const int64_t nb2 = src->nb[2];
- const int64_t nb3 = src->nb[3];
- const enum ggml_type type = src->type;
- const int64_t ts = ggml_type_size(type);
- const int64_t bs = ggml_blck_size(type);
- const int64_t i1_diff = i1_high - i1_low;
-
- const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*ne0/bs) {
- return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
- } else if (nb0 == ts) {
- return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
- } else {
- for (int64_t i1 = 0; i1 < i1_diff; i1++) {
- const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
- // pretend the row is a matrix with cols=1
- cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
- if (r != cudaSuccess) {
- return r;
- }
- }
- return cudaSuccess;
- }
-}
-
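-// cuBLAS path: on Volta+ with f16 or quantized src0 (contiguous, default precision) multiply in fp16 via cublasGemmEx, otherwise convert both inputs to fp32 and use cublasSgemm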
-static void ggml_cuda_op_mul_mat_cublas(
- ggml_backend_cuda_context & ctx,
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
- const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
- const int64_t src1_padded_row_size, cudaStream_t stream) {
-
- GGML_ASSERT(src0_dd_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_dd_i != nullptr);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne10 = src1->ne[0];
-
- const int64_t ne0 = dst->ne[0];
-
- const int64_t row_diff = row_high - row_low;
-
- int id = ggml_cuda_get_device();
-
- // the main device has a larger memory buffer to hold the results from all GPUs
- // ldc == nrows of the matrix that cuBLAS writes into
- int64_t ldc = id == ctx.device ? ne0 : row_diff;
-
- const int compute_capability = ggml_cuda_info().devices[id].cc;
-
- if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
- // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
- ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
- if (src0->type != GGML_TYPE_F16) {
- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
- GGML_ASSERT(to_fp16_cuda != nullptr);
- size_t ne = row_diff*ne00;
- src0_as_f16.alloc(ne);
- to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
- }
- const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
-
- ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
- GGML_ASSERT(to_fp16_cuda != nullptr);
- size_t ne = src1_ncols*ne10;
- src1_as_f16.alloc(ne);
- to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
- }
- const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
- CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
- CUBLAS_CHECK(
- cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
- row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
- src1_ptr, CUDA_R_16F, ne10,
- &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
- CUBLAS_COMPUTE_16F,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
- } else {
- ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
- ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
-
- if (src0->type != GGML_TYPE_F32) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
- GGML_ASSERT(to_fp32_cuda != nullptr);
- src0_ddq_as_f32.alloc(row_diff*ne00);
- to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
- }
- if (src1->type != GGML_TYPE_F32) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
- GGML_ASSERT(to_fp32_cuda != nullptr);
- src1_ddq_as_f32.alloc(src1_ncols*ne10);
- to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
- }
-
- const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
- const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
-
- const float alpha = 1.0f;
- const float beta = 0.0f;
-
- CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
- CUBLAS_CHECK(
- cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
- row_diff, src1_ncols, ne10,
- &alpha, src0_ddf_i, ne00,
- src1_ddf1_i, ne10,
- &beta, dst_dd_i, ldc));
- }
-
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_ddq_i);
- GGML_UNUSED(src1_padded_row_size);
-}
-
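-// peer access between the main device and the other devices is enabled only for batches of up to GGML_CUDA_PEER_MAX_BATCH_SIZE tokens (and only in release builds)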
-static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
- static bool peer_access_enabled = false;
-
- const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
-
- if (peer_access_enabled == enable_peer_access) {
- return;
- }
-
-#ifdef NDEBUG
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- ggml_cuda_set_device(id);
- CUDA_CHECK(cudaDeviceSynchronize());
- }
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- ggml_cuda_set_device(id);
-
- for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
- if (id == id_other) {
- continue;
- }
- if (id != main_device && id_other != main_device) {
- continue;
- }
-
- int can_access_peer;
- CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
- if (can_access_peer) {
- if (enable_peer_access) {
- cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
- if (err != cudaErrorPeerAccessAlreadyEnabled) {
- CUDA_CHECK(err);
- } else {
- // reset the error
- cudaGetLastError();
- }
- } else {
- cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
- if (err != cudaErrorPeerAccessNotEnabled) {
- CUDA_CHECK(err);
- } else {
- // reset the error
- cudaGetLastError();
- }
- }
- }
- }
- }
-
- ggml_cuda_set_device(main_device);
-#endif // NDEBUG
-
- peer_access_enabled = enable_peer_access;
-
- GGML_UNUSED(main_device);
-}
-
-static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
- void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
-
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
- // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
- cudaMemcpy3DPeerParms p = {};
- p.dstDevice = dstDevice;
- p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
- p.srcDevice = srcDevice;
- p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
- p.extent = make_cudaExtent(width, height, 1);
- return cudaMemcpy3DPeerAsync(&p, stream);
-#else
- // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
- GGML_UNUSED(dstDevice);
- GGML_UNUSED(srcDevice);
- return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
-}
-
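-// generic multi-GPU mul_mat driver: distributes rows of src0 across devices for split buffers, stages/quantizes src1 as needed and invokes op on each column chunk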
-static void ggml_cuda_op_mul_mat(
- ggml_backend_cuda_context & ctx,
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
- quantize_cuda_t quantize_src1) {
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
- const int64_t ne03 = src0->ne[3];
-
- const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
- const int64_t ne12 = src1->ne[2];
- const int64_t ne13 = src1->ne[3];
- const int64_t nrows1 = ggml_nrows(src1);
-
- GGML_ASSERT(ne03 == ne13);
-
- const int64_t ne0 = dst->ne[0];
- const int64_t ne1 = dst->ne[1];
-
- const int64_t nb2 = dst->nb[2];
- const int64_t nb3 = dst->nb[3];
-
- GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
- ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
- ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
-
- GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
-
- GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
-
- const int64_t i02_divisor = ne12 / ne02;
-
- const size_t src0_ts = ggml_type_size(src0->type);
- const size_t src0_bs = ggml_blck_size(src0->type);
- const size_t q8_1_ts = sizeof(block_q8_1);
- const size_t q8_1_bs = QK8_1;
-
- const bool src0_is_contiguous = ggml_is_contiguous(src0);
- const bool src1_is_contiguous = ggml_is_contiguous(src1);
-
- const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
-
- const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
- GGML_ASSERT(!(split && ne02 > 1));
- GGML_ASSERT(!(split && ne03 > 1));
- GGML_ASSERT(!(split && ne02 < ne12));
-
- ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
-
-
- std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
- if (split) {
- ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
- tensor_split = buft_ctx->tensor_split;
- }
-
- struct dev_data {
- int cc;
-
- ggml_cuda_pool_alloc<char> src0_dd_alloc;
- ggml_cuda_pool_alloc<float> src1_ddf_alloc;
- ggml_cuda_pool_alloc<char> src1_ddq_alloc;
- ggml_cuda_pool_alloc<float> dst_dd_alloc;
-
- char * src0_dd = nullptr;
- float * src1_ddf = nullptr; // float
- char * src1_ddq = nullptr; // q8_1
- float * dst_dd = nullptr;
-
- int64_t row_low;
- int64_t row_high;
- };
-
- dev_data dev[GGML_CUDA_MAX_DEVICES];
-
- int used_devices = 0;
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- dev[id].cc = ggml_cuda_info().devices[id].cc;
-
- // by default, use all rows
- dev[id].row_low = 0;
- dev[id].row_high = ne01;
-
- // for multi GPU, get the row boundaries from tensor split
- // and round to mul_mat_q tile sizes
- if (split) {
- const int64_t rounding = get_row_rounding(tensor_split);
-
- if (id != 0) {
- dev[id].row_low = ne01*tensor_split[id];
- if (dev[id].row_low < ne01) {
- dev[id].row_low -= dev[id].row_low % rounding;
- }
- }
-
- if (id != ggml_backend_cuda_get_device_count() - 1) {
- dev[id].row_high = ne01*tensor_split[id + 1];
- if (dev[id].row_high < ne01) {
- dev[id].row_high -= dev[id].row_high % rounding;
- }
- }
- }
- }
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
- continue;
- }
-
- used_devices++;
-
- const bool src1_on_device = id == src1_ctx->device;
- const bool dst_on_device = id == dst_ctx->device;
-
- ggml_cuda_set_device(id);
- cudaStream_t stream = ctx.stream(id, 0);
-
- if (src0_is_contiguous) {
- dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
- } else {
- // If src0 is not contiguous it will be copied to a temporary buffer.
- // This buffer needs to be cleared entirely because multiple regions will function as padding.
- const size_t nbytes_data = ggml_nbytes(src0);
- const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
- dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
- // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
-#ifndef GGML_USE_MUSA
- CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
-#else // GGML_USE_MUSA
- CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
-#endif // !GGML_USE_MUSA
- }
-
- // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
- if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
- const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
- const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
- CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
- }
-
- if (src1_on_device && src1_is_contiguous) {
- dev[id].src1_ddf = (float *) src1->data;
- } else {
- dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
- }
-
- if (quantize_src1) {
- size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
- if (quantize_src1 == quantize_mmq_q8_1_cuda) {
- src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
- }
- dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
-
- if (src1_on_device && src1_is_contiguous) {
- quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
- CUDA_CHECK(cudaGetLastError());
- }
- }
-
- if (dst_on_device) {
- dev[id].dst_dd = (float *) dst->data;
- } else {
- const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
- dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
- }
- }
-
- // if multiple devices are used they need to wait for the main device
- // here an event is recorded that signals that the main device has finished calculating the input data
- if (split && used_devices > 1) {
- ggml_cuda_set_device(ctx.device);
- CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
- }
-
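- // process src1 in chunks of MUL_MAT_SRC1_COL_STRIDE columns, rotating over GGML_CUDA_MAX_STREAMS streams when the work is split across multiple devices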
- const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
- for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
- const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
- const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
- continue;
- }
-
- const bool src1_on_device = id == src1_ctx->device;
- const bool dst_on_device = id == dst_ctx->device;
- const int64_t row_diff = dev[id].row_high - dev[id].row_low;
-
- ggml_cuda_set_device(id);
- cudaStream_t stream = ctx.stream(id, is);
-
- // wait for main GPU data if necessary
- if (split && (id != ctx.device || is != 0)) {
- CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
- }
-
- for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
- const int64_t i03 = i0 / ne12;
- const int64_t i02 = i0 % ne12;
-
- size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
- if (quantize_src1 == quantize_mmq_q8_1_cuda) {
- src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
- } else {
- src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
- }
-
- // for split tensors the data begins at i0 == i0_offset_low
- char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
- float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
- char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
- float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
-
- // the main device memory buffer can be on VRAM scratch, with space for all partial results
- // in that case an offset on dst_dd_i is needed
- if (id == ctx.device) {
- dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
- }
-
- // copy src0, src1 to device if necessary
- if (src1_is_contiguous) {
- if (id != ctx.device) {
- if (quantize_src1) {
- char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
- if (quantize_src1 == quantize_mmq_q8_1_cuda) {
- const size_t pitch = ne11*sizeof(block_q8_1_mmq);
- const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
- const size_t height = src1_padded_col_size/(4*QK8_1);
- CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
- } else {
- CUDA_CHECK(cudaMemcpyPeerAsync(
- src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
- }
- } else {
- float * src1_ddf_i_source = (float *) src1->data;
- src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
- CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
- src1_ncols*ne10*sizeof(float), stream));
- }
- }
- } else if (src1_on_device && !src1_is_contiguous) {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
- src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
- } else {
- GGML_ABORT("fatal error");
- }
-
- if (quantize_src1 && !src1_is_contiguous) {
- quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
- CUDA_CHECK(cudaGetLastError());
- }
-
- if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
- }
-
- // do the computation
- op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
- dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
- CUDA_CHECK(cudaGetLastError());
-
- // copy dst to host or other device if necessary
- if (!dst_on_device) {
- void * dst_off_device = dst->data;
- if (split) {
- // src0 = weight matrix is saved as a transposed matrix for better memory layout.
- // dst is NOT transposed.
- // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
- // Instead they need to be copied to the correct slice in ne0 = dst row index.
- // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
- float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
- GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
- dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
- CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
- dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
- } else {
- float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
- GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
- dhf_dst_i += src1_col_0*ne0;
- CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
- }
- }
-
- // add event for the main device to wait on until other device is done
- if (split && (id != ctx.device || is != 0)) {
- CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
- }
- }
- }
- }
-
- // main device waits for all other devices to be finished
- if (split && ggml_backend_cuda_get_device_count() > 1) {
- int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
- is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
-
- ggml_cuda_set_device(ctx.device);
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- if (dev[id].row_low == dev[id].row_high) {
- continue;
- }
- for (int64_t is = 0; is < is_max; ++is) {
- CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
- }
- }
- }
-}
-
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
- GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t ne12 = src1->ne[2];
-
- cudaStream_t main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(!ggml_is_permuted(src0));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t nb01 = src0->nb[1];
- const int64_t nb02 = src0->nb[2];
-
- const int64_t ne12 = src1->ne[2];
-
- cudaStream_t main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- const int64_t row_stride_x = nb01 / sizeof(half);
- const int64_t channel_stride_x = nb02 / sizeof(half);
-
- ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
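-// fill the per-matrix pointer arrays consumed by cublasGemmBatchedEx, mapping each (i12, i13) batch element to its src0/src1/dst slice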
-static __global__ void k_compute_batched_ptrs(
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
- const void ** ptrs_src, void ** ptrs_dst,
- int64_t ne12, int64_t ne13,
- int64_t ne23,
- size_t nb02, size_t nb03,
- size_t nb12, size_t nb13,
- size_t nbd2, size_t nbd3,
- int64_t r2, int64_t r3) {
- int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
- int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
- if (i13 >= ne13 || i12 >= ne12) {
- return;
- }
-
- int64_t i03 = i13 / r3;
- int64_t i02 = i12 / r2;
-
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t ne_dst = ggml_nelements(dst);
-
- cudaStream_t main_stream = ctx.stream();
-
- CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
-
- void * src0_ddq = src0->data;
- half * src0_f16 = (half *) src0_ddq;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- // convert src1 to fp16
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
- const int64_t ne_src1 = ggml_nelements(src1);
- src1_f16_alloc.alloc(ne_src1);
- GGML_ASSERT(to_fp16_cuda != nullptr);
- to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
- }
- half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
-
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
- char * dst_t;
-
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
- cudaDataType_t cu_data_type = CUDA_R_16F;
-
- // dst strides
- size_t nbd2 = dst->nb[2];
- size_t nbd3 = dst->nb[3];
-
- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
- const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
-
- const void * alpha = &alpha_f16;
- const void * beta = &beta_f16;
-
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
- dst_t = (char *) dst_f16.alloc(ne_dst);
-
- nbd2 /= sizeof(float) / sizeof(half);
- nbd3 /= sizeof(float) / sizeof(half);
- } else {
- dst_t = (char *) dst_ddf;
-
- cu_compute_type = CUBLAS_COMPUTE_32F;
- cu_data_type = CUDA_R_32F;
-
- alpha = &alpha_f32;
- beta = &beta_f32;
- }
-
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
-#if 0
- // use cublasGemmEx
- {
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- int i03 = i13 / r3;
- int i02 = i12 / r2;
-
- CUBLAS_CHECK(
- cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
- (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
- beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- }
- }
- }
-#else
-#ifdef GGML_USE_MUSA
- GGML_ASSERT(false);
-#else // !GGML_USE_MUSA
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
- // use cublasGemmStridedBatchedEx
- CUBLAS_CHECK(
- cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
- (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
- beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
- ne12*ne13,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- } else {
- // use cublasGemmBatchedEx
- const int ne23 = ne12*ne13;
-
- ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
- ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
-
- dim3 block_dims(ne13, ne12);
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_f16, src1_f16, dst_t,
- ptrs_src.get(), ptrs_dst.get(),
- ne12, ne13,
- ne23,
- nb02, nb03,
- src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
- src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
- nbd2, nbd3,
- r2, r3);
- CUDA_CHECK(cudaGetLastError());
-
- CUBLAS_CHECK(
- cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
- ne23,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- }
-#endif // GGML_USE_MUSA
-#endif
-
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
- }
-}
-
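-// choose between the f16 vector kernels, dmmv, mmvq, mmq, batched cuBLAS and the generic cuBLAS fallback based on types, batch size and device capabilities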
-static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
-
- bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
- bool use_mul_mat_q = ggml_is_quantized(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
- // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
- use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
- bool any_gpus_with_slow_fp16 = false;
-
- if (split) {
- ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
- auto & tensor_split = buft_ctx->tensor_split;
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
- // skip devices that are not going to do any work:
- if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
- continue;
- }
-
- const int cc = ggml_cuda_info().devices[id].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
- }
- } else {
- const int cc = ggml_cuda_info().devices[ctx.device].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
- }
-
- // debug helpers
- //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
- //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
- //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
- //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
-
- if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
- // FP32 precision KQ single-batch for batch size 1 without FlashAttention
- ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
- } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
- // FP32 precision KQV single-batch for batch size 1 without FlashAttention
- ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
- && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
- // KQ + KQV multi-batch without FlashAttention
- ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
- } else if (use_dequantize_mul_mat_vec) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
- } else if (use_mul_mat_vec_q) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
- } else if (use_mul_mat_q) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
- } else {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
- }
-}
-
-struct mmid_row_mapping {
- int32_t i1;
- int32_t i2;
-};
-
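-// gather the src1 rows routed to expert i02 into a contiguous buffer and record in row_mapping where each result row belongs in dst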
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
- int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
- const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
- int64_t ne11, int64_t ne10,
- size_t nb11, size_t nb12) {
- int32_t iid1 = blockIdx.x;
- int32_t id = blockIdx.y;
-
- const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
- if (row_id_i != i02) {
- return;
- }
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = iid1;
-
- __shared__ int src1_row;
- if (threadIdx.x == 0) {
- src1_row = atomicAdd(cur_src1_row, 1);
- row_mapping[src1_row] = {id, iid1};
- }
- __syncthreads();
-
- const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
- float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
- for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
- src1_row_contiguous[i] = src1_row_original[i];
- }
-}
-
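-// scatter the rows of the contiguous dst buffer back to their original positions using row_mapping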
-static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
- const mmid_row_mapping * __restrict__ row_mapping,
- int64_t ne0,
- size_t nb1, size_t nb2) {
- int32_t i = blockIdx.x;
-
- const int32_t i1 = row_mapping[i].i1;
- const int32_t i2 = row_mapping[i].i2;
-
- const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
- float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
- for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
- dst_row_original[j] = dst_row_contiguous[j];
- }
-}
-
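-// mul_mat_id (mixture of experts): with ne12 == 1 each selected expert row is multiplied directly; otherwise the rows for each expert are gathered into a contiguous buffer, multiplied, and scattered back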
-static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * ids = dst->src[2];
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
-
- cudaStream_t stream = ctx.stream();
-
- const int64_t n_as = ne02;
- const int64_t n_ids = ids->ne[0];
-
- std::vector<char> ids_host(ggml_nbytes(ids));
- const char * ids_dev = (const char *) ids->data;
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- ggml_tensor src0_row = *src0;
- ggml_tensor src1_row = *src1;
- ggml_tensor dst_row = *dst;
-
- char * src0_original = (char *) src0->data;
- char * src1_original = (char *) src1->data;
- char * dst_original = (char *) dst->data;
-
- src0_row.ne[2] = 1;
- src0_row.ne[3] = 1;
- src0_row.nb[3] = nb02;
-
- src1_row.ne[1] = 1;
- src1_row.ne[2] = 1;
- src1_row.ne[3] = 1;
- src1_row.nb[2] = nb11;
- src1_row.nb[3] = nb11;
-
- dst_row.ne[1] = 1;
- dst_row.ne[2] = 1;
- dst_row.ne[3] = 1;
- dst_row.nb[2] = nb1;
- dst_row.nb[3] = nb1;
-
- if (ne12 == 1) {
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = iid1;
-
- const int64_t i1 = id;
- const int64_t i2 = i12;
-
- src0_row.data = src0_original + i02*nb02;
- src1_row.data = src1_original + i11*nb11 + i12*nb12;
- dst_row.data = dst_original + i1*nb1 + i2*nb2;
-
- ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
- }
- }
- } else {
- ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
- ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
- src1_row.data = src1_contiguous.get();
- dst_row.data = dst_contiguous.get();
-
- for (int64_t i02 = 0; i02 < n_as; i02++) {
- int64_t num_src1_rows = 0;
-
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
- if (row_id_i != i02) {
- continue;
- }
-
- num_src1_rows++;
- }
- }
-
- if (num_src1_rows == 0) {
- continue;
- }
-
- ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
- ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
- CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
-
- {
- dim3 block_dims(std::min((unsigned int)ne10, 768u));
- dim3 grid_dims(ids->ne[1], n_ids);
- k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
- src1_original, src1_contiguous.get(),
- dev_cur_src1_row.get(), dev_row_mapping.get(),
- ids_dev, i02, ids->nb[1], ids->nb[0],
- ne11, ne10,
- nb11, nb12);
- CUDA_CHECK(cudaGetLastError());
- }
-
- src0_row.data = src0_original + i02*nb02;
-
- GGML_ASSERT(nb11 == sizeof(float)*ne10);
- GGML_ASSERT(nb1 == sizeof(float)*ne0);
-
- src1_row.ne[1] = num_src1_rows;
- src1_row.nb[1] = nb11;
- src1_row.nb[2] = num_src1_rows*nb11;
- src1_row.nb[3] = num_src1_rows*nb11;
-
- dst_row.ne[1] = num_src1_rows;
- dst_row.nb[1] = nb1;
- dst_row.nb[2] = num_src1_rows*nb1;
- dst_row.nb[3] = num_src1_rows*nb1;
-
- ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
- {
- dim3 block_dims(std::min((unsigned int)ne0, 768u));
- dim3 grid_dims(num_src1_rows);
- k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
- dst_original, dst_contiguous.get(),
- dev_row_mapping.get(),
- ne0,
- nb1, nb2);
- CUDA_CHECK(cudaGetLastError());
- }
- }
- }
-}
-
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
- // why is this here instead of mul_mat?
- if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
- ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
- }
-
- switch (dst->op) {
- case GGML_OP_ARGMAX:
- ggml_cuda_argmax(ctx, dst);
- break;
- case GGML_OP_COUNT_EQUAL:
- ggml_cuda_count_equal(ctx, dst);
- break;
- case GGML_OP_REPEAT:
- ggml_cuda_op_repeat(ctx, dst);
- break;
- case GGML_OP_REPEAT_BACK:
- ggml_cuda_op_repeat_back(ctx, dst);
- break;
- case GGML_OP_GET_ROWS:
- ggml_cuda_op_get_rows(ctx, dst);
- break;
- case GGML_OP_DUP:
- ggml_cuda_dup(ctx, dst);
- break;
- case GGML_OP_CPY:
- ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
- break;
- case GGML_OP_CONT:
- ggml_cuda_dup(ctx, dst);
- break;
- case GGML_OP_ADD:
- case GGML_OP_ADD1: // TODO: more efficient implementation
- ggml_cuda_op_add(ctx, dst);
- break;
- case GGML_OP_SUB:
- ggml_cuda_op_sub(ctx, dst);
- break;
- case GGML_OP_ACC:
- ggml_cuda_op_acc(ctx, dst);
- break;
- case GGML_OP_MUL:
- ggml_cuda_op_mul(ctx, dst);
- break;
- case GGML_OP_DIV:
- ggml_cuda_op_div(ctx, dst);
- break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(dst)) {
- case GGML_UNARY_OP_NEG:
- ggml_cuda_op_neg(ctx, dst);
- break;
- case GGML_UNARY_OP_STEP:
- ggml_cuda_op_step(ctx, dst);
- break;
- case GGML_UNARY_OP_GELU:
- ggml_cuda_op_gelu(ctx, dst);
- break;
- case GGML_UNARY_OP_SILU:
- ggml_cuda_op_silu(ctx, dst);
- break;
- case GGML_UNARY_OP_GELU_QUICK:
- ggml_cuda_op_gelu_quick(ctx, dst);
- break;
- case GGML_UNARY_OP_TANH:
- ggml_cuda_op_tanh(ctx, dst);
- break;
- case GGML_UNARY_OP_RELU:
- ggml_cuda_op_relu(ctx, dst);
- break;
- case GGML_UNARY_OP_SIGMOID:
- ggml_cuda_op_sigmoid(ctx, dst);
- break;
- case GGML_UNARY_OP_HARDSIGMOID:
- ggml_cuda_op_hardsigmoid(ctx, dst);
- break;
- case GGML_UNARY_OP_HARDSWISH:
- ggml_cuda_op_hardswish(ctx, dst);
- break;
- case GGML_UNARY_OP_EXP:
- ggml_cuda_op_exp(ctx, dst);
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_NORM:
- ggml_cuda_op_norm(ctx, dst);
- break;
- case GGML_OP_GROUP_NORM:
- ggml_cuda_op_group_norm(ctx, dst);
- break;
- case GGML_OP_CONCAT:
- ggml_cuda_op_concat(ctx, dst);
- break;
- case GGML_OP_UPSCALE:
- ggml_cuda_op_upscale(ctx, dst);
- break;
- case GGML_OP_PAD:
- ggml_cuda_op_pad(ctx, dst);
- break;
- case GGML_OP_ARANGE:
- ggml_cuda_op_arange(ctx, dst);
- break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- ggml_cuda_op_timestep_embedding(ctx, dst);
- break;
- case GGML_OP_LEAKY_RELU:
- ggml_cuda_op_leaky_relu(ctx, dst);
- break;
- case GGML_OP_RMS_NORM:
- ggml_cuda_op_rms_norm(ctx, dst);
- break;
- case GGML_OP_MUL_MAT:
- if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
- GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
- return false;
- } else {
- ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
- }
- break;
- case GGML_OP_MUL_MAT_ID:
- ggml_cuda_mul_mat_id(ctx, dst);
- break;
- case GGML_OP_OUT_PROD:
- ggml_cuda_out_prod(ctx, dst);
- break;
- case GGML_OP_SCALE:
- ggml_cuda_op_scale(ctx, dst);
- break;
- case GGML_OP_SQR:
- ggml_cuda_op_sqr(ctx, dst);
- break;
- case GGML_OP_SQRT:
- ggml_cuda_op_sqrt(ctx, dst);
- break;
- case GGML_OP_SIN:
- ggml_cuda_op_sin(ctx, dst);
- break;
- case GGML_OP_COS:
- ggml_cuda_op_cos(ctx, dst);
- break;
- case GGML_OP_CLAMP:
- ggml_cuda_op_clamp(ctx, dst);
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- break;
- case GGML_OP_DIAG_MASK_INF:
- ggml_cuda_op_diag_mask_inf(ctx, dst);
- break;
- case GGML_OP_SOFT_MAX:
- ggml_cuda_op_soft_max(ctx, dst);
- break;
- case GGML_OP_ROPE:
- ggml_cuda_op_rope(ctx, dst);
- break;
- case GGML_OP_IM2COL:
- ggml_cuda_op_im2col(ctx, dst);
- break;
- case GGML_OP_CONV_TRANSPOSE_1D:
- ggml_cuda_op_conv_transpose_1d(ctx,dst);
- break;
- case GGML_OP_POOL_2D:
- ggml_cuda_op_pool2d(ctx, dst);
- break;
- case GGML_OP_SUM:
- ggml_cuda_op_sum(ctx, dst);
- break;
- case GGML_OP_SUM_ROWS:
- ggml_cuda_op_sum_rows(ctx, dst);
- break;
- case GGML_OP_ARGSORT:
- ggml_cuda_op_argsort(ctx, dst);
- break;
- case GGML_OP_FLASH_ATTN_EXT:
- ggml_cuda_flash_attn_ext(ctx, dst);
- break;
- case GGML_OP_CROSS_ENTROPY_LOSS:
- ggml_cuda_cross_entropy_loss(ctx, dst);
- break;
- case GGML_OP_RWKV_WKV6:
- ggml_cuda_op_rwkv_wkv6(ctx, dst);
- break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- ggml_cuda_cross_entropy_loss_back(ctx, dst);
- break;
- case GGML_OP_OPT_STEP_ADAMW:
- ggml_cuda_opt_step_adamw(ctx, dst);
- break;
- default:
- return false;
- }
-
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess) {
- GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
- CUDA_CHECK(err);
- }
-
- return true;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend
-
-static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- return cuda_ctx->name.c_str();
-}
-
-static void ggml_backend_cuda_free(ggml_backend_t backend) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- delete cuda_ctx;
- delete backend;
-}
-
-static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
-}
-
-static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
-}
-
-static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
- ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
- ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
-
- if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
- return false;
- }
-
- if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
- return false;
- }
-
- // device -> device copy
- ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
- ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
-
- ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
- ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
-
- if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
-#endif
- return false;
- }
-
- if (backend_src != backend_dst) {
- // copy on src stream
- if (cuda_ctx_src->device == cuda_ctx_dst->device) {
- CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
- } else {
-#ifdef GGML_CUDA_NO_PEER_COPY
- return false;
-#else
- CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
-#endif
- }
-
- // record event on src stream after the copy
- if (!cuda_ctx_src->copy_event) {
- ggml_cuda_set_device(cuda_ctx_src->device);
- CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
- }
-
- CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
-
- // wait on dst stream for the copy to complete
- CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
- } else {
- // src and dst are on the same backend
- CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
- }
- return true;
-}
-
-static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
-
- GGML_UNUSED(backend);
-}
-
-#ifdef USE_CUDA_GRAPH
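-// Record the properties of a ggml graph node (op, shapes, strides, data and source addresses,
-// op params) so that they can be compared against the graph of the next call: any mismatch
-// means the captured CUDA graph cannot be replayed as-is and has to be updated.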
-static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
- graph_node_properties->node_address = node->data;
- graph_node_properties->node_op = node->op;
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- graph_node_properties->ne[i] = node->ne[i];
- graph_node_properties->nb[i] = node->nb[i];
- }
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
- }
- memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
-}
-
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
- if (node->data != graph_node_properties->node_address &&
- node->op != GGML_OP_CPY &&
- node->op != GGML_OP_VIEW) {
- return false;
- }
-
- if (node->op != graph_node_properties->node_op) {
- return false;
- }
-
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (node->ne[i] != graph_node_properties->ne[i]) {
- return false;
- }
- if (node->nb[i] != graph_node_properties->nb[i]) {
- return false;
- }
- }
-
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (node->src[i] &&
- node->src[i]->data != graph_node_properties->src_address[i] &&
- node->op != GGML_OP_CPY &&
- node->op != GGML_OP_VIEW
- ) {
- return false;
- }
- }
-
- if (node->op == GGML_OP_SCALE &&
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
- return false;
- }
-
- return true;
-}
-#endif
-
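-// Evaluate a ggml graph on the CUDA backend. When CUDA graphs are usable and the graph topology is
-// stable across calls, the kernel launches are captured into a CUDA graph once and then replayed on
-// subsequent calls, with the per-token copy-kernel arguments patched into the captured graph.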
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- ggml_cuda_set_device(cuda_ctx->device);
-
-#ifdef USE_CUDA_GRAPH
- static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
- // Objects required for CUDA Graph
- if (cuda_ctx->cuda_graph == nullptr) {
- cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
- }
-
- bool use_cuda_graph = true;
- bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters that need to be updated in the graph for each token
- std::vector<void *> ggml_cuda_cpy_fn_ptrs;
-
- if (cuda_ctx->cuda_graph->graph == nullptr) {
- if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
- cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
- }
- }
-
-    // Disable CUDA graphs in the presence of the env var, an old GPU, a use case that is
-    // changing too rapidly, or a previous graph capture failure.
-    // Also disable for multi-GPU for now. TODO: investigate.
- if (disable_cuda_graphs_due_to_env
- || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
- || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
- || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
- use_cuda_graph = false;
- }
-
- if (use_cuda_graph) {
- if (cuda_ctx->cuda_graph->instance == nullptr) {
- cuda_graph_update_required = true;
- }
-
- // Check if the graph size has changed
- if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
- cuda_graph_update_required = true;
- cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
- }
-
- // Loop over nodes in GGML graph to determine if CUDA graph update is required
- // and store properties to allow this comparison for the next token
- for (int i = 0; i < cgraph->n_nodes; i++) {
- bool has_matching_properties = true;
- if (!cuda_graph_update_required) {
- has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
- }
- if (!has_matching_properties) {
- cuda_graph_update_required = true;
- }
- set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
- }
-
- // Loop over nodes in GGML graph to obtain info needed for CUDA graph
- cuda_ctx->cuda_graph->updated_kernel_arg.clear();
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor * node = cgraph->nodes[i];
-
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
- continue;
- }
-
- if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
- use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
- }
-
- if (node->op == GGML_OP_MUL_MAT_ID) {
- use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
- }
-
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
- use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
- }
-
- if (node->op == GGML_OP_CPY) {
- // store the copy op parameter which changes with each token.
- cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
- // store a pointer to each copy op CUDA kernel to identify it later
- void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
- if (!ptr) {
- use_cuda_graph = false;
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
- } else {
- if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
- ggml_cuda_cpy_fn_ptrs.push_back(ptr);
- }
- }
- }
-
- if (!use_cuda_graph) {
- break;
- }
- }
-
- // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
- if (use_cuda_graph && cuda_graph_update_required) {
- cuda_ctx->cuda_graph->number_consecutive_updates++;
- } else {
- cuda_ctx->cuda_graph->number_consecutive_updates = 0;
- }
-
- if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
- cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
- }
- }
-
- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
- CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
- }
-
-#else
- bool use_cuda_graph = false;
- bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
- bool graph_evaluated_or_captured = false;
-
- while (!graph_evaluated_or_captured) {
- // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
- // With the use of CUDA graphs, the execution will be performed by the graph launch.
- if (!use_cuda_graph || cuda_graph_update_required) {
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor * node = cgraph->nodes[i];
-
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
- continue;
- }
-
-#ifndef NDEBUG
- assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- if (node->src[j] != nullptr) {
- assert(node->src[j]->buffer);
- assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
- ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
- }
- }
-#endif
-
- bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
- if (!ok) {
- GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
- }
- GGML_ASSERT(ok);
- }
- }
-
-#ifdef USE_CUDA_GRAPH
- if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
- if (cuda_ctx->cuda_graph->graph != nullptr) {
- CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
- cuda_ctx->cuda_graph->graph = nullptr;
- }
- CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
-
-#if 0
- if (disable_cuda_graphs_due_to_failed_capture) {
- use_cuda_graph = false;
- cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
- } else {
- graph_evaluated_or_captured = true; // CUDA graph has been captured
- }
-#endif
- graph_evaluated_or_captured = true; // CUDA graph has been captured
- } else {
- graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
- }
- }
-
- if (use_cuda_graph) {
- if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
- }
-
- // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
- if (cuda_graph_update_required) {
- // Extract nodes from graph
- // First call with null argument gets number of nodes in graph
- CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
- // Subsequent call with non-null argument gets nodes
- cuda_ctx->cuda_graph->nodes.clear();
- cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
- cuda_ctx->cuda_graph->params.clear();
- cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
- if (cuda_ctx->cuda_graph->num_nodes > 0) {
- CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
- // Loop over nodes, and extract kernel parameters from each node
- for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
- cudaGraphNodeType node_type;
- CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
- if (node_type == cudaGraphNodeTypeKernel) {
- cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
- if (stat == cudaErrorInvalidDeviceFunction) {
-                        // This fails due to incorrect handling of CUDA BLAS nodes by the CUDA runtime.
-                        // We don't need to update BLAS nodes, so clear the error and move on.
- cudaGetLastError();
- } else {
- GGML_ASSERT(stat == cudaSuccess);
- }
- }
- }
- }
- }
-
- // One of the arguments to the copy kernel is updated for each token, hence we need to
- // replace that argument with the updated value in the CUDA graph
- if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
- int k = 0;
- for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if (std::count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
- char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
- cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
- CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
- }
- }
- }
-
- // Update graph executable
- cudaGraphExecUpdateResultInfo result_info;
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
- if (stat == cudaErrorGraphExecUpdateFailure) {
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
-#endif
-            // The pre-existing graph exec cannot be updated due to violated constraints,
-            // so instead clear the error and re-instantiate.
- cudaGetLastError();
- CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
- cuda_ctx->cuda_graph->instance = nullptr;
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
- } else {
- GGML_ASSERT(stat == cudaSuccess);
- }
- // Launch graph
- CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
-#else
- graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
- }
-
- return GGML_STATUS_SUCCESS;
-}
-
-static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
-}
-
-static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
- if (ggml_backend_is_cuda(backend)) {
- CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
- } else {
-#if 0
- // untested
- auto wait_fn = [](void * user_data) {
- ggml_backend_event_t event = (ggml_backend_event_t)user_data;
- ggml_backend_event_synchronize(event);
- };
-
- CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
-#endif
- GGML_ABORT("fatal error");
- }
-}
-
-static const ggml_backend_i ggml_backend_cuda_interface = {
- /* .get_name = */ ggml_backend_cuda_get_name,
- /* .free = */ ggml_backend_cuda_free,
- /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
- /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
- /* .synchronize = */ ggml_backend_cuda_synchronize,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_cuda_graph_compute,
- /* .event_record = */ ggml_backend_cuda_event_record,
- /* .event_wait = */ ggml_backend_cuda_event_wait,
-};
-
-static ggml_guid_t ggml_backend_cuda_guid() {
- static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
- return &guid;
-}
-
-bool ggml_backend_is_cuda(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
-}
-
-int ggml_backend_cuda_get_device_count() {
- return ggml_cuda_info().device_count;
-}
-
-void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
- cudaDeviceProp prop;
- CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
- snprintf(description, description_size, "%s", prop.name);
-}
-
-void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
- ggml_cuda_set_device(device);
-
- CUDA_CHECK(cudaMemGetInfo(free, total));
-}
-
-bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
- if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
- return false;
- }
-
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
- cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
- if (err != cudaSuccess) {
- // clear the error
- cudaGetLastError();
-
- GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
- size / 1024.0 / 1024.0, cudaGetErrorString(err));
- return false;
- }
- return true;
-#else
- return false;
-#endif
-}
-
-void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
- if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
- return;
- }
-
- cudaError_t err = cudaHostUnregister(buffer);
- if (err != cudaSuccess) {
- // clear the error
- cudaGetLastError();
- }
-}
-
-
-// backend device
-
-struct ggml_backend_cuda_device_context {
- int device;
- std::string name;
- std::string description;
-};
-
-static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
- return ctx->name.c_str();
-}
-
-static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
- return ctx->description.c_str();
-}
-
-static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
- ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaMemGetInfo(free, total));
-}
-
-static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
- props->name = ggml_backend_cuda_device_get_name(dev);
- props->description = ggml_backend_cuda_device_get_description(dev);
- props->type = ggml_backend_cuda_device_get_type(dev);
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
- bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
-#ifdef GGML_CUDA_NO_PEER_COPY
- bool events = false;
-#else
- bool events = true;
-#endif
-
- props->caps = {
- /* .async = */ true,
- /* .host_buffer = */ host_buffer,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ events,
- };
-}
-
-static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
- GGML_UNUSED(params);
- ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
- return ggml_backend_cuda_init(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
- return ggml_backend_cuda_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return ggml_backend_cuda_host_buffer_type();
-}
-
-// TODO: move these functions here
-static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
-
- // split buffers can only be used with GGML_OP_MUL_MAT
- if (op->op != GGML_OP_MUL_MAT) {
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
- return false;
- }
- }
- }
-
- // check if all the sources are allocated on this device
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
- ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
- if (buft_ctx->device != dev_ctx->device) {
- return false;
- }
- }
- }
-
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_NEG:
- case GGML_UNARY_OP_STEP:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_SIGMOID:
- case GGML_UNARY_OP_HARDSIGMOID:
- case GGML_UNARY_OP_HARDSWISH:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_TANH:
- case GGML_UNARY_OP_EXP:
- return ggml_is_contiguous(op->src[0]);
- default:
- return false;
- }
- break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- {
- struct ggml_tensor * a = op->src[0];
- struct ggml_tensor * b = op->src[1];
- if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
- return false;
- }
- if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
- return false;
- }
-#ifdef GGML_USE_MUSA
- if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
- !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
- return false;
- }
-#endif // GGML_USE_MUSA
- switch (a->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_Q8_K:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
-#ifdef GGML_USE_MUSA
- if (a->type == GGML_TYPE_Q3_K) {
- return false;
- }
-#endif // GGML_USE_MUSA
- return true;
- default:
- return false;
- }
- } break;
- case GGML_OP_OUT_PROD:
- return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
- case GGML_OP_GET_ROWS:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_F32:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- return true;
- default:
- return false;
- }
- } break;
- case GGML_OP_CPY:
- {
- ggml_type src0_type = op->src[0]->type;
- ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
- return true;
- }
- if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
- return true;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
- return true;
- }
- return false;
- } break;
- case GGML_OP_DUP:
- {
- ggml_type src0_type = op->src[0]->type;
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
- case GGML_OP_ARGMAX:
- case GGML_OP_COUNT_EQUAL:
- {
- return true;
- } break;
- case GGML_OP_REPEAT:
- {
- ggml_type src0_type = op->src[0]->type;
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
- case GGML_OP_REPEAT_BACK:
- return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
- case GGML_OP_CONCAT:
- {
- ggml_type src0_type = op->src[0]->type;
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
- case GGML_OP_CONV_TRANSPOSE_1D:
- {
- ggml_type src0_type = op->src[0]->type;
- ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- return false;
- } break;
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- case GGML_OP_SUB:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- return true;
- case GGML_OP_CONT:
- return op->src[0]->type != GGML_TYPE_BF16;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- return true;
- case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_IM2COL:
- case GGML_OP_POOL_2D:
- case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_ARGSORT:
- case GGML_OP_ACC:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
- case GGML_OP_ARANGE:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_LEAKY_RELU:
- case GGML_OP_RWKV_WKV6:
- return true;
- case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
- return false;
-#endif
- if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
- return false;
- }
- if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
- return true;
- }
- if (op->src[0]->ne[0] == 128) {
- return true;
- }
- if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
- return true;
- }
- const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
- return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
- }
- case GGML_OP_CROSS_ENTROPY_LOSS:
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- case GGML_OP_OPT_STEP_ADAMW:
- return true;
- default:
- return false;
- }
-}
-
-static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
-}
-
-static int64_t get_op_batch_size(const ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_GET_ROWS:
- return 0;
- case GGML_OP_MUL_MAT:
- return op->ne[1];
- case GGML_OP_MUL_MAT_ID:
- case GGML_OP_ROPE:
- return op->ne[2];
- default:
- return ggml_nrows(op);
- }
-}
-
-static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- const int min_batch_size = 32;
-
- return get_op_batch_size(op) >= min_batch_size;
-
- GGML_UNUSED(dev);
-}
-
-static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
-#ifdef GGML_CUDA_NO_PEER_COPY
- return nullptr;
-#else
- ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;
-
- ggml_cuda_set_device(dev_ctx->device);
-
- cudaEvent_t event;
- CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-
- return new ggml_backend_event {
- /* .device = */ dev,
- /* .context = */ event,
- };
-#endif
-}
-
-static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
- GGML_UNUSED(dev);
-
- CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
- delete event;
-}
-
-static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
- GGML_UNUSED(dev);
- CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
-}
-
-static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
- /* .get_name = */ ggml_backend_cuda_device_get_name,
- /* .get_description = */ ggml_backend_cuda_device_get_description,
- /* .get_memory = */ ggml_backend_cuda_device_get_memory,
- /* .get_type = */ ggml_backend_cuda_device_get_type,
- /* .get_props = */ ggml_backend_cuda_device_get_props,
- /* .init_backend = */ ggml_backend_cuda_device_init_backend,
- /* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type,
- /* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_cuda_device_supports_op,
- /* .supports_buft = */ ggml_backend_cuda_device_supports_buft,
- /* .offload_op = */ ggml_backend_cuda_device_offload_op,
- /* .event_new = */ ggml_backend_cuda_device_event_new,
- /* .event_free = */ ggml_backend_cuda_device_event_free,
- /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
-};
-
-// backend reg
-
-struct ggml_backend_cuda_reg_context {
- std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {
- GGML_UNUSED(reg);
- return GGML_CUDA_NAME;
-}
-
-static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {
- ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
- return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
- GGML_ASSERT(index < ctx->devices.size());
- return ctx->devices[index];
-}
-
-static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- GGML_UNUSED(reg);
- if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
- return (void *)ggml_backend_cuda_split_buffer_type;
- }
- if (strcmp(name, "ggml_backend_register_host_buffer") == 0) {
- return (void *)ggml_backend_cuda_register_host_buffer;
- }
- if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) {
- return (void *)ggml_backend_cuda_unregister_host_buffer;
- }
- return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
- /* .get_name = */ ggml_backend_cuda_reg_get_name,
- /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count,
-    /* .get_device          = */ ggml_backend_cuda_reg_get_device,
- /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address,
-};
-
-// backend registry
-ggml_backend_reg_t ggml_backend_cuda_reg() {
- static ggml_backend_reg reg;
- static bool initialized = false;
-
- {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- if (!initialized) {
- ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
-
- for (int i = 0; i < ggml_cuda_info().device_count; i++) {
- ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
- dev_ctx->device = i;
- dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
-
- ggml_cuda_set_device(i);
- cudaDeviceProp prop;
- CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
- dev_ctx->description = prop.name;
-
- ggml_backend_dev_t dev = new ggml_backend_device {
- /* .interface = */ ggml_backend_cuda_device_interface,
-                    /* .reg       = */ &reg,
- /* .context = */ dev_ctx
- };
- ctx->devices.push_back(dev);
- }
-
- reg = ggml_backend_reg {
- /* .interface = */ ggml_backend_cuda_reg_interface,
- /* .context = */ ctx
- };
- }
-
- initialized = true;
- }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_cuda_init(int device) {
- if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
- GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
- return nullptr;
- }
-
- ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
- if (ctx == nullptr) {
- GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
- return nullptr;
- }
-
- ggml_backend_t cuda_backend = new ggml_backend {
- /* .guid = */ ggml_backend_cuda_guid(),
- /* .interface = */ ggml_backend_cuda_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
- /* .context = */ ctx,
- };
-
- return cuda_backend;
-}
+++ /dev/null
-#include "common.cuh"
-#include "rwkv-wkv.cuh"
-
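-// RWKV time-mixing (wkv) kernel: one thread block per (batch, head) pair and one thread per channel
-// within the head. The recurrent state of each channel is kept in registers, while the k/r/td values
-// of the current timestep (and the per-head tf) are staged in shared memory and consumed in
-// float4-vectorized chunks.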
-static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
- const int tid = threadIdx.x;
- const int bid = blockIdx.x;
-
- const int head_size = CUDA_WKV_BLOCK_SIZE;
- const int batch_i = bid / H;
- const int head_i = bid % H;
- const int state_size = C * head_size;
- const int n_seq_tokens = T / B;
-
- float state[head_size];
- __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
-
- #pragma unroll
- for (int i = 0; i < head_size; i++) {
- state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
- }
-
- __syncthreads();
- _tf[tid] = tf[head_i * head_size + tid];
- __syncthreads();
-
- for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
- __syncthreads();
- _k[tid] = k[t];
- _r[tid] = r[t];
- _td[tid] = td[t];
- __syncthreads();
-
- const float _v = v[t];
- float y = 0;
- for (int j = 0; j < head_size; j += 4) {
- const float4& k = (float4&)(_k[j]);
- const float4& r = (float4&)(_r[j]);
- const float4& tf = (float4&)(_tf[j]);
- const float4& td = (float4&)(_td[j]);
- float4& s = (float4&)(state[j]);
- float4 kv;
-
- kv.x = k.x * _v;
- kv.y = k.y * _v;
- kv.z = k.z * _v;
- kv.w = k.w * _v;
-
- y += r.x * (tf.x * kv.x + s.x);
- y += r.y * (tf.y * kv.y + s.y);
- y += r.z * (tf.z * kv.z + s.z);
- y += r.w * (tf.w * kv.w + s.w);
-
- s.x = s.x * td.x + kv.x;
- s.y = s.y * td.y + kv.y;
- s.z = s.z * td.z + kv.z;
- s.w = s.w * td.w + kv.w;
- }
- dst[t] = y;
- }
-
- #pragma unroll
- for (int i = 0; i < head_size; i++) {
- dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
- }
-}
-
-void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const float * k_d = (const float *)dst->src[0]->data;
- const float * v_d = (const float *)dst->src[1]->data;
- const float * r_d = (const float *)dst->src[2]->data;
- const float * tf_d = (const float *)dst->src[3]->data;
- const float * td_d = (const float *)dst->src[4]->data;
- const float * s_d = (const float *)dst->src[5]->data;
-
- const int64_t B = dst->src[5]->ne[1];
- const int64_t T = dst->src[0]->ne[3];
- const int64_t C = dst->ne[0];
- const int64_t H = dst->src[0]->ne[2];
-
- float * dst_d = (float *)dst->data;
-
- cudaStream_t stream = ctx.stream();
-
- GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
- GGML_ASSERT(C % H == 0);
- GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
-
- rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
-}
+++ /dev/null
-#include "common.cuh"
-
-#define CUDA_WKV_BLOCK_SIZE 64
-
-void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+++ /dev/null
-#include "ggml-impl.h"
-#include "ggml-backend.h"
-#include "ggml-backend-impl.h"
-#include "ggml-kompute.h"
-
-// These are generated at build time by a cmake custom command
-#include "shaderop_scale.h"
-#include "shaderop_scale_8.h"
-#include "shaderop_add.h"
-#include "shaderop_addrow.h"
-#include "shaderop_mul.h"
-#include "shaderop_silu.h"
-#include "shaderop_relu.h"
-#include "shaderop_gelu.h"
-#include "shaderop_softmax.h"
-#include "shaderop_norm.h"
-#include "shaderop_rmsnorm.h"
-#include "shaderop_diagmask.h"
-#include "shaderop_mul_mat_f16.h"
-#include "shaderop_mul_mat_q8_0.h"
-#include "shaderop_mul_mat_q4_0.h"
-#include "shaderop_mul_mat_q4_1.h"
-#include "shaderop_mul_mat_q4_k.h"
-#include "shaderop_mul_mat_q6_k.h"
-#include "shaderop_mul_mat_mat_f32.h"
-#include "shaderop_getrows_f32.h"
-#include "shaderop_getrows_f16.h"
-#include "shaderop_getrows_q4_0.h"
-#include "shaderop_getrows_q4_1.h"
-#include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
-#include "shaderop_cpy_f16_f16.h"
-#include "shaderop_cpy_f16_f32.h"
-#include "shaderop_cpy_f32_f16.h"
-#include "shaderop_cpy_f32_f32.h"
-
-#include <algorithm>
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <cstdio>
-#include <cstring>
-#include <iostream>
-#include <memory>
-#include <mutex>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include <kompute/Kompute.hpp>
-#include <vulkan/vulkan.hpp>
-
-#ifdef __linux__
-#include <cstdlib> // for setenv
-#endif
-
-#define QK4_0 32
-#define QR4_0 2
-#define QK4_1 32
-#define QK_NL 16
-
-typedef ggml_fp16_t half;
-
-static std::string ggml_kompute_format_name(int device) {
- return "Kompute" + std::to_string(device);
-}
-
-struct ggml_kompute_context {
- int device;
- std::string name;
- std::shared_ptr<vk::DescriptorPool> pool;
-
- ggml_kompute_context(int device)
- : device(device), name(ggml_kompute_format_name(device)) {}
-};
-
-// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
-// and consolidate the init functions and simplify object lifetime management. As it currently stands,
-// we *have* to have the kompute manager no matter what for device discovery, but the kompute context
-// is only created when a device is set and vulkan is explicitly turned on.
-static ggml_kompute_context *s_kompute_context = nullptr;
-
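-// Lazily creates the global kp::Manager used for device discovery and dispatch, and recreates it if
-// the underlying Vulkan instance has been destroyed in the meantime.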
-class kompute_manager {
- kp::Manager *s_mgr = nullptr;
-
-public:
- kp::Manager *operator()() {
- if (s_mgr && !s_mgr->hasInstance()) {
- destroy();
- }
- if (!s_mgr) {
- s_mgr = new kp::Manager;
- }
- return s_mgr;
- }
-
- void destroy() {
- delete s_mgr;
- s_mgr = nullptr;
- }
-};
-
-static kompute_manager komputeManager;
-
-struct ggml_vk_memory {
- void *data = nullptr;
- size_t size = 0;
- vk::DeviceMemory *primaryMemory = nullptr;
- vk::Buffer *primaryBuffer = nullptr;
- vk::DeviceMemory *stagingMemory = nullptr;
- vk::Buffer *stagingBuffer = nullptr;
-};
-
-#ifdef __linux__
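-// Runs at library load time (constructor attribute): ask RADV to enable Smart Access Memory, without
-// overwriting an existing RADV_PERFTEST setting.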
-__attribute__((constructor))
-static void enable_sam() {
- setenv("RADV_PERFTEST", "sam", false);
-}
-#endif
-
-static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
- vk::PhysicalDeviceFeatures availableFeatures;
- physical_device.getFeatures(&availableFeatures);
-
- if (!availableFeatures.shaderInt16)
- return false;
-
- vk::PhysicalDeviceVulkan11Features availableFeatures11;
- vk::PhysicalDeviceVulkan12Features availableFeatures12;
-
- availableFeatures11.pNext = &availableFeatures12;
- availableFeatures12.pNext = nullptr;
-
- vk::PhysicalDeviceFeatures2 features2;
- features2.pNext = &availableFeatures11;
-
- physical_device.getFeatures2(&features2);
-
- if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
- !availableFeatures11.storageBuffer16BitAccess) {
- return false;
- }
-
- if (!availableFeatures12.storageBuffer8BitAccess ||
- !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
- !availableFeatures12.shaderFloat16 ||
- !availableFeatures12.shaderInt8) {
- return false;
- }
-
- return true;
-}
-
-static const char * ggml_vk_getVendorName(uint32_t vendorID) {
- switch (vendorID) {
- case 0x10DE:
- return "nvidia";
- case 0x1002:
- return "amd";
- case 0x8086:
- return "intel";
- default:
- return "unknown";
- }
-}
-
-static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
- std::vector<ggml_vk_device> results;
- if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
- return results;
-
- std::vector<vk::PhysicalDevice> physical_devices;
- try {
- physical_devices = komputeManager()->listDevices();
- } catch (vk::SystemError & err) {
- std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
- return results;
- }
-
- uint32_t deviceCount = physical_devices.size();
- if (deviceCount == 0)
- return results;
-
- std::unordered_map<std::string, size_t> count_by_name;
-
- for (uint32_t i = 0; i < deviceCount; i++) {
- const auto & physical_device = physical_devices[i];
-
- VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
- VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
- const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
- const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
- if (major < 1 || minor < 2)
- continue;
-
- if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
- continue;
-
- size_t heapSize = 0;
- for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
- VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
- if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
- heapSize = heap.size;
- break;
- }
- }
-
- if (heapSize < memoryRequired)
- continue;
-
- auto ext_props = physical_device.enumerateDeviceExtensionProperties();
- bool has_maintenance4 = false;
-
- // Check if maintenance4 is supported
- for (const auto & properties : ext_props) {
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
- has_maintenance4 = true;
- }
- }
-
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
- vk::PhysicalDeviceProperties2 dev_props2;
- vk::PhysicalDeviceMaintenance3Properties dev_props3;
- vk::PhysicalDeviceMaintenance4Properties dev_props4;
- dev_props2.pNext = &dev_props3;
- dev_props3.pNext = &subgroup_props;
- if (has_maintenance4) {
- subgroup_props.pNext = &dev_props4;
- }
- physical_device.getProperties2(&dev_props2);
-
- if (subgroup_props.subgroupSize < 32)
- continue;
-
- ggml_vk_device d;
- d.index = i;
- d.type = dev_props.deviceType;
- d.heapSize = heapSize;
- d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
- d.subgroupSize = subgroup_props.subgroupSize;
- d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
-
- if (has_maintenance4) {
- d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
- } else {
- d.maxAlloc = dev_props3.maxMemoryAllocationSize;
- }
-
- std::string name(dev_props.deviceName);
- size_t n_idx = ++count_by_name[name];
- if (n_idx > 1) {
- name += " (" + std::to_string(n_idx) + ")";
- }
- d.name = strdup(name.c_str());
-
- results.push_back(d);
- }
-
- std::stable_sort(results.begin(), results.end(),
- [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
- if (lhs.type != rhs.type) {
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
-
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
- }
- return lhs.heapSize < rhs.heapSize;
- }
- );
-
- return results;
-}
-
-static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
- static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
- return devices;
-}
-
-static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
- devices.erase(
- std::remove_if(devices.begin(), devices.end(),
- [&targetVendor](const ggml_vk_device& device) {
- return device.vendor != targetVendor;
- }),
- devices.end()
- );
-}
-
-static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
- devices.erase(
- std::remove_if(devices.begin(), devices.end(),
- [&targetName](const ggml_vk_device& device) {
- return device.name != targetName;
- }),
- devices.end()
- );
-}
-
-static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
- if (name.empty())
- return false;
-
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
- if (name == "amd" || name == "nvidia" || name == "intel") {
- ggml_vk_filterByVendor(devices, name);
- } else if (name != "gpu") {
- ggml_vk_filterByName(devices, name);
- }
-
- if (devices.empty())
- return false;
-
- *device = devices.front();
- return true;
-}
-
-bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
- return ggml_vk_get_device(device, memoryRequired, std::string(name));
-}
-
-bool ggml_vk_has_vulkan() {
- return komputeManager()->hasVulkan();
-}
-
-bool ggml_vk_has_device() {
- return komputeManager()->hasDevice();
-}
-
-ggml_vk_device ggml_vk_current_device() {
- if (!komputeManager()->hasDevice())
- return ggml_vk_device();
-
- auto devices = ggml_vk_available_devices();
- ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
- GGML_ASSERT(!devices.empty());
- return devices.front();
-}
-
-static
-void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
- std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
- vk::DescriptorPoolSize(
- vk::DescriptorType::eStorageBuffer,
-            3 * size // Descriptor count is the number of possible tensors to pass into an algorithm
- )
- };
-
- vk::DescriptorPoolCreateInfo descriptorPoolInfo(
- vk::DescriptorPoolCreateFlags(),
- size, // Max sets
- static_cast<uint32_t>(descriptorPoolSizes.size()),
- descriptorPoolSizes.data());
-
- ctx->pool = std::make_shared<vk::DescriptorPool>();
- vk::Result r = komputeManager()->device()->createDescriptorPool(
- &descriptorPoolInfo, nullptr, ctx->pool.get());
- if (r != vk::Result::eSuccess)
-        std::cerr << "Error allocating descriptor pool: " << vk::to_string(r) << std::endl;
-}
-
-static
-void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
- if (ctx->pool) {
- komputeManager()->device()->destroy(
- *ctx->pool,
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
- ctx->pool = nullptr;
- }
-}
-
-static
-vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
- vk::BufferCreateInfo bufferCreateInfo;
- bufferCreateInfo.size = size;
- bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
- vk::BufferUsageFlagBits::eTransferSrc |
- vk::BufferUsageFlagBits::eTransferDst;
- bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
-
- vk::Buffer *vkBuffer = new vk::Buffer;
- vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
- if (r != vk::Result::eSuccess)
- std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
- return vkBuffer;
-}
-
-static
-vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
-
- uint32_t memoryTypeIndex = -1;
- bool memoryTypeIndexFound = false;
- vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
- for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
- const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
- const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
- if (memoryHeap.size < size) {
- continue;
- }
-
- if (requirements.memoryTypeBits & (1 << i)) {
- if (((memoryProperties.memoryTypes[i]).propertyFlags &
- flags) == flags) {
- memoryTypeIndex = i;
- memoryTypeIndexFound = true;
- if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
- *isHostVisible = true;
- }
- break;
- }
- }
- }
- if (!memoryTypeIndexFound) {
- throw std::runtime_error(
- "Memory type index for buffer creation not found");
- }
-
- vk::MemoryAllocateInfo allocInfo;
- allocInfo.allocationSize = size;
- allocInfo.memoryTypeIndex = memoryTypeIndex;
- vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
- vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
- if (r != vk::Result::eSuccess) {
- std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
- throw std::runtime_error("Error allocating vulkan memory.");
- }
- return vkDeviceMemory;
-}
-
-static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
- size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
-
- // If offset is already aligned, return it directly
- if (offset % minStorageBufferOffsetAlignment == 0) {
- return offset;
- }
-
- // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
- return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
-}
-
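-// Allocates the backing memory for a backend buffer: a device-local primary buffer, plus a mapped
-// host-visible (cached, coherent) staging buffer whenever the primary allocation itself is not
-// host-visible.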
-static ggml_vk_memory ggml_vk_allocate(size_t size) {
- ggml_vk_memory memory;
- bool isHostVisible = false;
- {
- memory.primaryBuffer = ggml_vk_allocate_buffer(size);
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
- memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
- komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
- if (isHostVisible) {
- vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
- if (r != vk::Result::eSuccess)
- std::cerr << "Error mapping memory" << vk::to_string(r);
- }
- }
-
- if (!isHostVisible) {
- memory.stagingBuffer = ggml_vk_allocate_buffer(size);
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
- vk::MemoryPropertyFlagBits::eHostCoherent |
- vk::MemoryPropertyFlagBits::eHostCached;
- memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
- komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
- vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
- if (r != vk::Result::eSuccess)
- std::cerr << "Error mapping memory" << vk::to_string(r);
- }
-
- memory.size = size;
- return memory;
-}
-
-static void ggml_vk_free_memory(ggml_vk_memory &memory)
-{
- komputeManager()->device()->destroy(
- *memory.primaryBuffer,
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
- if (memory.stagingBuffer) {
- komputeManager()->device()->destroy(
- *memory.stagingBuffer,
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
- }
- komputeManager()->device()->freeMemory(
- *memory.primaryMemory,
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
- if (memory.stagingMemory) {
- komputeManager()->device()->freeMemory(
- *memory.stagingMemory,
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
- }
-}
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
-
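- // Resolve a ggml tensor (following views) to the ggml_vk_memory of the kompute buffer that owns it
- // and report the tensor's byte offset within that buffer.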
-static
-ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
- ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
-
- // compatibility with ggml-backend
- GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
-
- ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
-
- const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
-
- GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
-
- offset = uint64_t(ioffs);
- return buf_ctx;
-}
-
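- // Wrap the tensor's region of the backing Vulkan buffers in a kp::Tensor. The buffer offset is
- // aligned down to the storage buffer alignment; the remainder is returned via alignedOffset so the
- // callers can pass it on to the kernels as an extra element offset.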
-static
-const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
- uint64_t originalOffset = 0;
- auto * res = ggml_vk_find_tensor(t, originalOffset);
- if (!res) {
- static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
- return nullTensor;
- }
-
- // Create a tensor whose memory will be composed of our buffers at the correct offset
- const size_t nelements = ggml_nelements(t);
- size_t nbytes = ggml_nbytes(t);
-
- size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
- if (alignedOffset) {
- *alignedOffset = originalOffset - vulkanOffset;
- nbytes += *alignedOffset;
- }
-
- return komputeManager()->tensor(
- t->data,
- nelements,
- nbytes, kp::Tensor::TensorDataTypes::eFloat,
- res->primaryMemory, res->primaryBuffer,
- res->stagingMemory, res->stagingBuffer,
- vulkanOffset);
-}
-
-static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
- if (size % sizeof(uint32_t) != 0) {
- throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
- }
-
- const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
- size_t count = size / sizeof(uint32_t);
- return std::vector<uint32_t>(data_ptr, data_ptr + count);
-}
-
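- // Exact-division helper used to turn byte offsets into element offsets for push constants;
- // aborts if the offset is not a multiple of the element size.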
-inline static
-uint32_t safe_divide(uint32_t a, uint32_t b) {
- if (b <= 1) {
- return a;
- }
- if ((a % b) != 0) {
- fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
- GGML_ABORT("safe_divide result would've had remainder");
- }
- return a / b;
-}
-
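- // The op recorders below all follow the same pattern: build the push constants, then either create
- // a kp::Algorithm for the op's SPIR-V module (cached by name in the kompute manager) or fetch the
- // cached one and refresh its tensors, workgroup and push constants before recording the dispatch.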
-static void ggml_vk_add(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
- int32_t ne0,
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
- kp::shader_data::op_add_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00;
- int32_t nb00, nb01, nb02, nb03;
- int32_t ne10, ne11, ne12, ne13;
- int32_t nb10, nb11, nb12, nb13;
- int32_t ne0;
- int32_t nb0, nb1, nb2, nb3;
- } const pushConsts {
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00,
- nb00, nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb10, nb11, nb12, nb13,
- ne0,
- nb0, nb1, nb2, nb3
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_addrow(kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- uint32_t size, uint32_t row = 0) {
-
- const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
- kp::shader_data::op_addrow_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- uint32_t row;
- } const pushConsts {
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- row
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__))
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
- else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({size});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
- int32_t ne0,
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
- kp::shader_data::op_mul_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00;
- int32_t nb00, nb01, nb02, nb03;
- int32_t ne10, ne11, ne12, ne13;
- int32_t nb10, nb11, nb12, nb13;
- int32_t ne0;
- int32_t nb0, nb1, nb2, nb3;
- } const pushConsts {
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00,
- nb00, nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb10, nb11, nb12, nb13,
- ne0,
- nb0, nb1, nb2, nb3
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_scale(kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& in,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inOff, uint32_t outOff,
- uint32_t size, float scale) {
- const static auto spirv_1 = getSpirvShader(
- kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
- );
- const static auto spirv_8 = getSpirvShader(
- kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
- );
-
- struct PushConstants {
- uint32_t inOff, outOff;
- float scale;
- } const pushConsts {
- safe_divide(inOff, 4), safe_divide(outOff, 4),
- scale
- };
-
- const auto * spirv = &spirv_1;
- std::string name(__func__);
- if (size % 8 == 0) {
- size /= 8;
- name += "_8";
- spirv = &spirv_8;
- }
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({in, out});
- s_algo->setWorkgroup({size});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_xxlu(
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& in,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inOff, uint32_t outOff,
- uint32_t size
-) {
- struct PushConstants {
- uint32_t inOff, outOff;
- } const pushConsts {
- safe_divide(inOff, 4), safe_divide(outOff, 4),
- };
-
- auto name = std::string(__func__) + "_" + suffix;
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({in, out});
- s_algo->setWorkgroup({size});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_silu(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
- kp::shader_data::op_silu_comp_spv_len);
-
- ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_relu(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
- kp::shader_data::op_relu_comp_spv_len);
-
- ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_gelu(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
- kp::shader_data::op_gelu_comp_spv_len);
-
- ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_soft_max(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
- float scale
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
- kp::shader_data::op_softmax_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne01, ne02;
- float scale;
- int32_t mask;
- } pushConsts {
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne01, ne02,
- scale,
- bool(inB)
- };
-
- auto & inB_ = inB ? inB : inA;
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- // FIXME: The softmax kernel needs to be fixed to use the subgroup size, which can vary by device
- const uint32_t local_x = 32;
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB_, out});
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_norm_(
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& in,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inOff, uint32_t outOff,
- int32_t ne00, int32_t nb01,
- int32_t nrows, float epsilon
-) {
- GGML_ASSERT(nb01%sizeof(float) == 0);
- GGML_ASSERT(ne00%sizeof(float) == 0);
-
- struct PushConstants {
- uint32_t inOff, outOff;
- uint32_t ne00, nb01;
- float eps;
- } pushConsts {
- safe_divide(inOff, 4), safe_divide(outOff, 4),
- (uint32_t)ne00, (uint32_t)nb01, epsilon
- };
-
- auto name = std::string(__func__) + "_" + suffix;
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({in, out});
- s_algo->setWorkgroup({(uint32_t)nrows});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_norm(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
- kp::shader_data::op_norm_comp_spv_len);
-
- ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_rms_norm(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
- kp::shader_data::op_rmsnorm_comp_spv_len);
-
- ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& in,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inOff, uint32_t outOff,
- uint32_t n_past,
- int32_t ne00, int32_t ne01, int32_t ne02) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
- kp::shader_data::op_diagmask_comp_spv_len);
-
- struct PushConstants {
- uint32_t inOff, outOff;
- uint32_t n_past;
- int32_t ne00, ne01;
- } pushConsts {
- safe_divide(inOff, 4), safe_divide(outOff, 4),
- n_past,
- ne00, ne01
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__))
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
- else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({in, out});
- s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_f16(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02,
- uint32_t nb00, uint32_t nb01, uint32_t nb02,
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
- uint32_t nb10, uint32_t nb11, uint32_t nb12,
- int32_t ne0, int32_t ne1,
- uint32_t r2, uint32_t r3
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
- kp::shader_data::op_mul_mat_f16_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne01, ne02;
- uint32_t nb00, nb01, nb02;
- int32_t ne10, ne11, ne12;
- uint32_t nb10, nb11, nb12;
- int32_t ne0, ne1;
- uint32_t r2, r3;
- } pushConsts {
- safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne01, ne02,
- nb00, nb01, nb02,
- ne10, ne11, ne12,
- nb10, nb11, nb12,
- ne0, ne1,
- r2, r3
- };
-
- const unsigned ny = unsigned((ne11 + 4 - 1)/4);
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02,
- uint32_t nb01, uint32_t nb02,
- int32_t ne11, int32_t ne12,
- uint32_t nb11, uint32_t nb12,
- uint32_t nb1, uint32_t nb2) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
- kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne01, ne02, ne11, ne12;
- uint32_t nb01, nb02;
- uint32_t nb11, nb12;
- uint32_t nb1, nb2;
- } pushConsts {
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne01, ne02, ne11, ne12,
- nb01, nb02, nb11, nb12,
- nb1, nb2
- };
-
- const uint32_t local_x = ggml_vk_current_device().subgroupSize;
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
- {inA, inB, out}, spirv,
- {unsigned(ne01),
- unsigned(ne11),
- unsigned(std::max(ne12, ne02))
- },
- {local_x},
- {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned(ne01),
- unsigned(ne11),
- unsigned(std::max(ne12, ne02)),
- });
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
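- // Shared recorder for the quantized mul_mat kernels (q4_0, q4_1, q8_0); block_size is the element
- // size used to convert inAOff into an element offset for the push constants.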
-static void ggml_vk_mul_mat_impl(
- const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02,
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
- int32_t ne0, int32_t ne1,
- uint32_t r2, uint32_t r3
-) {
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne01, ne02;
- int32_t ne10, ne12;
- int32_t ne0, ne1;
- uint32_t r2, r3;
- } pushConsts {
- safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne01, ne02,
- ne10, ne12,
- ne0, ne1,
- r2, r3
- };
-
- auto name = std::string(__func__) + "_" + suffix;
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_0(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
- kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
-
- ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_1(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
- kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
-
- ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q8_0(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
- kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
-
- ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_mul_mat_q4_k(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
- int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
- int32_t ne1, int32_t r2, int32_t r3
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
- kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
- } pushConsts {
- 0, 0, 0,
- ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_q6_k(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
- int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
-) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
- kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne10, ne0, ne1, ne01, gqa;
- } pushConsts {
- inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne10, ne0, ne1, ne01, ne12/ne02
- };
-
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(__func__)) {
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(__func__);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_get_rows(
- const std::vector<uint32_t>& spirv,
- const char * suffix,
- unsigned element_size, unsigned qk,
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t nb01, int32_t nb1,
- uint32_t size
-) {
- GGML_ASSERT(nb01%element_size == 0);
- GGML_ASSERT(nb1%sizeof(float) == 0);
- if (qk) GGML_ASSERT(ne00%qk == 0);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t ne00, nb01, nb1;
- } pushConsts {
- safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, nb01, nb1
- };
-
- auto name = std::string(__func__) + "_" + suffix;
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({size});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f32(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
- kp::shader_data::op_getrows_f32_comp_spv_len);
-
- ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f16(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
- kp::shader_data::op_getrows_f16_comp_spv_len);
-
- ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_0(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
- kp::shader_data::op_getrows_q4_0_comp_spv_len);
-
- ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_1(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
- kp::shader_data::op_getrows_q4_1_comp_spv_len);
-
- ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q6_k(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
- kp::shader_data::op_getrows_q6_k_comp_spv_len);
- ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_rope(
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& inA,
- const std::shared_ptr<kp::Tensor>& inB,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
- float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
- int32_t ne01, int32_t ne02, int32_t ne03,
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
- int32_t ne0,
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
- GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
-
- static const auto spirv_f16 = getSpirvShader(
- kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
- );
- static const auto spirv_f32 = getSpirvShader(
- kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
- );
-
- int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
-
- GGML_ASSERT(nb03 % type_size == 0);
- GGML_ASSERT(nb02 % type_size == 0);
- GGML_ASSERT(nb01 % type_size == 0);
- GGML_ASSERT(nb00 % type_size == 0);
- GGML_ASSERT(nb3 % type_size == 0);
- GGML_ASSERT(nb2 % type_size == 0);
- GGML_ASSERT(nb1 % type_size == 0);
- GGML_ASSERT(nb0 % type_size == 0);
-
- struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
- int32_t n_dims, mode, n_ctx_orig;
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- uint32_t nb00, nb01, nb02, nb03;
- int32_t ne0;
- uint32_t nb0, nb1, nb2, nb3;
- } pushConsts {
- safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
- n_dims, mode, n_ctx_orig,
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
- nb00, nb01, nb02, nb03,
- ne0,
- nb0, nb1, nb2, nb3
- };
-
- auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name)) {
- s_algo = komputeManager()->algorithm<float, PushConstants>(
- name, s_kompute_context->pool.get(), {inA, inB, out},
- src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
- {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
- );
- } else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_cpy(
- const std::vector<uint32_t>& spirv,
- uint32_t in_element_size, uint32_t out_element_size,
- kp::Sequence& seq,
- const std::shared_ptr<kp::Tensor>& in,
- const std::shared_ptr<kp::Tensor>& out,
- uint32_t inOff, uint32_t outOff,
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
- int32_t ne0, int32_t ne1, int32_t ne2,
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
- struct PushConstants {
- uint32_t inOff, outOff;
- int32_t ne00, ne01, ne02;
- uint32_t nb00, nb01, nb02, nb03;
- int32_t ne0, ne1, ne2;
- uint32_t nb0, nb1, nb2, nb3;
- } pushConsts {
- safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
- ne00, ne01, ne02,
- nb00, nb01, nb02, nb03,
- ne0, ne1, ne2,
- nb0, nb1, nb2, nb3
- };
-
- std::string name = std::string(__func__)
- + "_i_" + std::to_string(in_element_size)
- + "_o_" + std::to_string(out_element_size);
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
- if (!komputeManager()->hasAlgorithm(name))
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
- else {
- s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({in, out});
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
- s_algo->setPushConstants<PushConstants>({pushConsts});
- s_algo->updateDescriptors(s_kompute_context->pool.get());
- }
- seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f16(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
- kp::shader_data::op_cpy_f32_f16_comp_spv_len);
- ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f32(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
- kp::shader_data::op_cpy_f32_f32_comp_spv_len);
- ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f16(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
- kp::shader_data::op_cpy_f16_f16_comp_spv_len);
- ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f32(Args&&... args) {
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
- kp::shader_data::op_cpy_f16_f32_comp_spv_len);
- ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
-}
-
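- // Advertise which ggml ops (and which operand types) this backend can execute; anything else is
- // left to other backends.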
-static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_SILU:
- return ggml_is_contiguous(op->src[0]);
- default:
- ;
- }
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_SCALE:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_RMS_NORM:
- case GGML_OP_NORM:
- case GGML_OP_ROPE:
- return true;
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- break;
- default:
- return false;
- }
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- break;
- default:
- return false;
- }
- return true;
- case GGML_OP_DIAG_MASK_INF:
- return op->ne[3] == 1;
- case GGML_OP_GET_ROWS:
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q6_K:
- return op->ne[2] == 1 && op->ne[3] == 1;
- default:
- ;
- }
- return false;
- case GGML_OP_MUL_MAT:
- if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
- return false;
-
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_Q6_K:
- return op->ne[3] == 1;
- case GGML_TYPE_F16:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q4_K:
- return true;
- default:
- ;
- }
- default:
- ;
- }
- return false;
-
- GGML_UNUSED(dev);
-}
-
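- // Execute the graph: split its nodes across n_seq command sequences, record the supported ops into
- // each sequence, evaluate the sequences asynchronously and wait for all of them to finish.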
-static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
- const int n_seq = 8;
-
- // FIXME: Figure out whether the descriptor pool can be made smaller; right now it is sized to the
- // number of nodes in the graph.
- ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
-
- std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
-
- for (auto& sequence : sequences) {
- sequence = komputeManager()->sequence();
- }
- for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
- const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
-
- auto& seq = *sequences[seq_idx];
-
- const int node_start = (seq_idx + 0) * n_nodes_per_seq;
- const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
-
- bool any_commands_recorded = false;
-
- for (int i = node_start; i < node_end; ++i) {
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
- struct ggml_tensor * dst = gf->nodes[i];
- GGML_ASSERT(dst->data != nullptr);
-
- if (ggml_is_empty(dst)) {
- continue;
- }
-
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- continue; // noop -> next node
- default:
- break;
- }
-
- any_commands_recorded = true;
-
- const int32_t ne00 = src0 ? src0->ne[0] : 0;
- const int32_t ne01 = src0 ? src0->ne[1] : 0;
- const int32_t ne02 = src0 ? src0->ne[2] : 0;
- const int32_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint32_t nb00 = src0 ? src0->nb[0] : 0;
- const uint32_t nb01 = src0 ? src0->nb[1] : 0;
- const uint32_t nb02 = src0 ? src0->nb[2] : 0;
- const uint32_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int32_t ne10 = src1 ? src1->ne[0] : 0;
- const int32_t ne11 = src1 ? src1->ne[1] : 0;
- const int32_t ne12 = src1 ? src1->ne[2] : 0;
- const int32_t ne13 = src1 ? src1->ne[3] : 0;
-
- const uint32_t nb10 = src1 ? src1->nb[0] : 0;
- const uint32_t nb11 = src1 ? src1->nb[1] : 0;
- const uint32_t nb12 = src1 ? src1->nb[2] : 0;
- const uint32_t nb13 = src1 ? src1->nb[3] : 0;
-
- const int32_t ne0 = dst ? dst->ne[0] : 0;
- const int32_t ne1 = dst ? dst->ne[1] : 0;
- const int32_t ne2 = dst ? dst->ne[2] : 0;
-// const int32_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint32_t nb0 = dst ? dst->nb[0] : 0;
- const uint32_t nb1 = dst ? dst->nb[1] : 0;
- const uint32_t nb2 = dst ? dst->nb[2] : 0;
- const uint32_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
- uint32_t off_src0 = 0;
- uint32_t off_src1 = 0;
- uint32_t off_dst = 0;
- const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
- const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
- const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
-
- switch (dst->op) {
- case GGML_OP_ADD:
- {
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- // src1 is a row
- ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
- } else {
- ggml_vk_add(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne03,
- nb00, nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb10, nb11, nb12, nb13,
- ne0,
- nb0, nb1, nb2, nb3
- );
- }
- } break;
- case GGML_OP_MUL:
- {
- ggml_vk_mul(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne03,
- nb00, nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb10, nb11, nb12, nb13,
- ne0,
- nb0, nb1, nb2, nb3
- );
- } break;
- case GGML_OP_SCALE:
- {
- float scale; memcpy(&scale, dst->op_params, sizeof(float));
-
- ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
- } break;
- case GGML_OP_UNARY:
- {
- int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
- switch (ggml_get_unary_op(gf->nodes[i])) {
- case GGML_UNARY_OP_SILU:
- {
- ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
- } break;
- case GGML_UNARY_OP_RELU:
- {
- ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
- } break;
- case GGML_UNARY_OP_GELU:
- {
- GGML_ASSERT(n % 8 == 0);
- ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
- } break;
- default:
- {
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ABORT("fatal error");
- }
- }
- } break;
- case GGML_OP_SOFT_MAX:
- {
- float scale;
- float max_bias;
-
- memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
-
-#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
- GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
-#pragma message("TODO: add ALiBi support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
- GGML_ASSERT(max_bias == 0.0f);
-
- ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- const int n_past = ((int32_t *)(dst->op_params))[0];
- ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
- } break;
- case GGML_OP_NORM:
- {
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
- ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
- } break;
- case GGML_OP_RMS_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
- ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
- } break;
- case GGML_OP_MUL_MAT:
- {
- GGML_ASSERT(ne00 == ne10);
-
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- const uint32_t r2 = ne12/ne02;
- const uint32_t r3 = ne13/ne03;
-
- if (src1t != GGML_TYPE_F32) {
- fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
- goto not_implemented;
- }
-
- if (ggml_is_transposed(src0) ||
- ggml_is_transposed(src1)) {
- fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
- goto not_implemented;
- }
-
- switch (src0t) {
- case GGML_TYPE_F32:
- ggml_vk_mul_mat_mat_f32(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
- );
- break;
- case GGML_TYPE_F16:
- ggml_vk_mul_mat_f16(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
- ne0, ne1, r2, r3
- );
- break;
- case GGML_TYPE_Q8_0:
- ggml_vk_mul_mat_q8_0(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
- );
- break;
- case GGML_TYPE_Q4_0:
- ggml_vk_mul_mat_q4_0(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
- );
- break;
- case GGML_TYPE_Q4_1:
- ggml_vk_mul_mat_q4_1(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
- );
- break;
- case GGML_TYPE_Q4_K:
- ggml_vk_mul_mat_q4_k(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
- );
- break;
- case GGML_TYPE_Q6_K:
- ggml_vk_mul_mat_q6_k(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
- );
- break;
- default: {
- fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
- goto not_implemented;
- }
- }
-
- } break;
- case GGML_OP_GET_ROWS:
- {
- if (src0t == GGML_TYPE_F32) {
- ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
- } else if (src0t == GGML_TYPE_F16) {
- ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
- } else if (src0t == GGML_TYPE_Q4_0) {
- ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
- } else if (src0t == GGML_TYPE_Q4_1) {
- ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
- } else if (src0t == GGML_TYPE_Q6_K) {
- ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
- } else {
- fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
- goto not_implemented;
- }
- } break;
- case GGML_OP_ROPE:
- {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
-
- GGML_ASSERT(ne10 == ne02);
- GGML_ASSERT(src0t == dstt);
- // const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
- ggml_vk_rope(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
- ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
- );
- } break;
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- {
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- switch (dstt) {
- case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
- case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
- default: goto not_implemented;
- }
- } break;
- case GGML_TYPE_F16:
- {
- switch (dstt) {
- case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
- case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
- default: goto not_implemented;
- }
- } break;
- default: goto not_implemented;
- }
- } break;
- default: goto not_implemented;
- }
- continue;
- not_implemented: {}
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- //GGML_ABORT("fatal error");
- }
-
- // Evaluate sequence
- if (any_commands_recorded) {
- seq.evalAsync();
- }
- }
-
- // Wait for all sequences to finish
- for (auto& sequence : sequences) {
- if (sequence->isRunning())
- sequence->evalAwait();
- }
-
- ggml_vk_free_descriptor_pool(ctx);
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<half>::dataType()
-{
- return TensorDataTypes::eFloat;
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<uint8_t>::dataType()
-{
- return TensorDataTypes::eUnsignedInt;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend interface
-
-struct ggml_backend_kompute_buffer_type_context {
- int device;
- int device_ref = 0;
- uint64_t buffer_alignment;
- uint64_t max_alloc;
- std::string name;
-
- ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
- : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
-};
-
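- // The buffer type ref-counts the Vulkan device: the first reference initializes the device with the
- // required extensions, and the last unref (below) destroys the kompute manager.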
-static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
- if (!ctx->device_ref) {
- komputeManager()->initializeDevice(
- ctx->device, {}, {
- "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
- "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
- }
- );
- }
-
- assert(ggml_vk_has_device());
- ctx->device_ref++;
-}
-
-static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
- assert(ctx->device_ref > 0);
-
- ctx->device_ref--;
-
- if (!ctx->device_ref) {
- komputeManager.destroy();
- }
-}
-
-static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- auto * memory = (ggml_vk_memory *)buffer->context;
- if (ggml_vk_has_device()) {
- ggml_vk_free_memory(*memory);
- }
- delete memory;
-}
-
-static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
- return ((ggml_vk_memory *)buffer->context)->data;
-}
-
-static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_UNUSED(buffer);
-
- const auto res = ggml_vk_get_tensor(tensor);
- GGML_ASSERT(res);
-
- memcpy((char *)tensor->data + offset, data, size);
-
- komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
-}
-
-static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_UNUSED(buffer);
-
- const auto res = ggml_vk_get_tensor(tensor);
- GGML_ASSERT(res);
-
- komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
-
- memcpy(data, (const char *)tensor->data + offset, size);
-}
-
-static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- auto * memory = (ggml_vk_memory *)buffer->context;
- memset(memory->data, value, buffer->size);
-
- if (memory->stagingBuffer)
- komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
- /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
- /* .get_base = */ ggml_backend_kompute_buffer_get_base,
- /* .init_tensor = */ NULL,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
- /* .cpy_tensor = */ NULL,
- /* .clear = */ ggml_backend_kompute_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// default buffer type
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
- return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- ggml_backend_kompute_device_ref(buft);
- auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
- return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
-}
-
-static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
- return ctx->buffer_alignment;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
- return ctx->max_alloc;
-}
-
-static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
- /* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- auto devices = ggml_vk_available_devices();
- int32_t device_count = (int32_t) devices.size();
- GGML_ASSERT(device < device_count);
- GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
-
- static ggml_backend_buffer_type
- ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
-
- static bool ggml_backend_kompute_buffer_type_initialized = false;
-
- if (!ggml_backend_kompute_buffer_type_initialized) {
- for (int32_t i = 0; i < device_count; i++) {
- ggml_backend_kompute_buffer_types[i] = {
- /* .iface = */ ggml_backend_kompute_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
- /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
- };
- }
- ggml_backend_kompute_buffer_type_initialized = true;
- }
-
- return &ggml_backend_kompute_buffer_types[device];
-}
-
-// backend
-
-static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
- return ctx->name.c_str();
-}
-
-static void ggml_backend_kompute_free(ggml_backend_t backend) {
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-
- assert(ctx == s_kompute_context);
- s_kompute_context = nullptr;
- if (ctx != nullptr) {
- delete ctx;
- }
-
- delete backend;
-}
-
-static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
- ggml_vk_graph_compute(ctx, cgraph);
- return GGML_STATUS_SUCCESS;
-}
-
-static struct ggml_backend_i kompute_backend_i = {
- /* .get_name = */ ggml_backend_kompute_name,
- /* .free = */ ggml_backend_kompute_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_kompute_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_kompute_guid() {
- static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
- return &guid;
-}
-
-ggml_backend_t ggml_backend_kompute_init(int device) {
- GGML_ASSERT(s_kompute_context == nullptr);
- s_kompute_context = new ggml_kompute_context(device);
-
- ggml_backend_t kompute_backend = new ggml_backend {
- /* .guid = */ ggml_backend_kompute_guid(),
- /* .interface = */ kompute_backend_i,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
- /* .context = */ s_kompute_context,
- };
-
- return kompute_backend;
-}
-
-bool ggml_backend_is_kompute(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
-}
-
-static size_t ggml_backend_kompute_get_device_count() {
- auto devices = ggml_vk_available_devices();
- return devices.size();
-}
-
-static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
- auto devices = ggml_vk_available_devices();
- GGML_ASSERT((size_t) device < devices.size());
- snprintf(description, description_size, "%s", devices[device].name);
-}
-
-static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
- auto devices = ggml_vk_available_devices();
- GGML_ASSERT((size_t) device < devices.size());
- *total = devices[device].heapSize;
- *free = devices[device].heapSize;
-}
-
-//////////////////////////
-
-struct ggml_backend_kompute_device_context {
- int device;
- std::string name;
- std::string description;
-};
-
-static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- return ctx->name.c_str();
-}
-
-static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- return ctx->description.c_str();
-}
-
-static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- ggml_backend_kompute_get_device_memory(ctx->device, free, total);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- return ggml_backend_kompute_buffer_type(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
- return false;
- }
-
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
-
- return buft_ctx->device == ctx->device;
-}
-
-static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_kompute_device_get_name(dev);
- props->description = ggml_backend_kompute_device_get_description(dev);
- props->type = ggml_backend_kompute_device_get_type(dev);
- ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
- GGML_UNUSED(params);
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
- return ggml_backend_kompute_init(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- const int min_batch_size = 32;
-
- return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
- GGML_UNUSED(dev);
-}
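// Worked example of the offload rule above (illustrative numbers): with
// min_batch_size == 32, a GGML_OP_MUL_MAT node with ne[1] == 48 is offloaded, a
// GGML_OP_GET_ROWS node never is, and a GGML_OP_MUL_MAT_ID node is offloaded once
// ne[2] >= 32, regardless of ne[1].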
-
-static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
- /* .get_name = */ ggml_backend_kompute_device_get_name,
- /* .get_description = */ ggml_backend_kompute_device_get_description,
- /* .get_memory = */ ggml_backend_kompute_device_get_memory,
- /* .get_type = */ ggml_backend_kompute_device_get_type,
- /* .get_props = */ ggml_backend_kompute_device_get_props,
- /* .init_backend = */ ggml_backend_kompute_device_init,
- /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_kompute_device_supports_op,
- /* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
- /* .offload_op = */ ggml_backend_kompute_device_offload_op,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
- GGML_UNUSED(reg);
- return "Kompute";
-}
-
-static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
- GGML_UNUSED(reg);
- return ggml_backend_kompute_get_device_count();
-}
-
-static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
- static std::vector<ggml_backend_dev_t> devices;
-
- static bool initialized = false;
-
- {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- if (!initialized) {
- for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
- ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
- char desc[256];
- ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
- ctx->device = i;
- ctx->name = "Kompute" + std::to_string(i);
- ctx->description = desc;
- devices.push_back(new ggml_backend_device {
- /* .iface = */ ggml_backend_kompute_device_i,
- /* .reg = */ reg,
- /* .context = */ ctx,
- });
- }
- initialized = true;
- }
- }
-
- GGML_ASSERT(device < devices.size());
- return devices[device];
-}
-
-static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
- /* .get_name = */ ggml_backend_kompute_reg_get_name,
- /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
- /* .get_device = */ ggml_backend_kompute_reg_get_device,
- /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_kompute_reg() {
- static ggml_backend_reg reg = {
- /* .iface = */ ggml_backend_kompute_reg_i,
- /* .context = */ nullptr,
- };
-
- return &reg;
-}
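// A minimal usage sketch of the entry points above (illustrative, not part of the
// original sources; it assumes the public "ggml-backend.h" and "ggml-kompute.h"
// headers and that at least one Vulkan device is present):

#include "ggml-backend.h"
#include "ggml-kompute.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_kompute_init(0); // first available device
    if (backend == NULL || !ggml_backend_is_kompute(backend)) {
        return 1;
    }
    // build and evaluate ggml graphs through the generic ggml-backend API here
    ggml_backend_free(backend);
    return 0;
}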
+++ /dev/null
-#import "ggml-metal.h"
-
-#import "ggml-impl.h"
-#import "ggml-backend-impl.h"
-
-#import <Foundation/Foundation.h>
-
-#import <Metal/Metal.h>
-
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-// max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 64
-
-// max number of MTLCommandBuffer used to submit a graph for processing
-#define GGML_METAL_MAX_COMMAND_BUFFERS 8
-
-#define UNUSED(x) (void)(x)
-
-// globals
-
-// stand-in for MTLGPUFamilyMetal3 (the enum value is not available in some SDKs / environments)
-static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
-
-// initialized in ggml_backend_metal_reg
-static struct ggml_backend_reg g_ggml_backend_metal_reg;
-static struct ggml_backend_device g_ggml_backend_metal_device;
-
-// information about a Metal device
-// note: assumes single GPU device - the default one
-// TODO: support multiple GPU devices
-static struct ggml_backend_metal_device_context {
- id<MTLDevice> mtl_device;
- int mtl_device_ref_count;
-
- bool has_simdgroup_reduction;
- bool has_simdgroup_mm;
- bool has_bfloat;
- bool use_bfloat;
-
- char name[128];
-} g_ggml_ctx_dev_main = {
- /*.mtl_device =*/ nil,
- /*.mtl_device_ref_count =*/ 0,
- /*.has_simdgroup_reduction =*/ false,
- /*.has_simdgroup_mm =*/ false,
- /*.has_bfloat =*/ false,
- /*.use_bfloat =*/ false,
- /*.name =*/ "",
-};
-
-// acquire
-static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
- assert(ctx != NULL);
-
- if (ctx->mtl_device == nil) {
- ctx->mtl_device = MTLCreateSystemDefaultDevice();
-
- ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
- ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
-
- ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
-
- ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
- ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
-
-#if defined(GGML_METAL_USE_BF16)
- ctx->use_bfloat = ctx->has_bfloat;
-#else
- ctx->use_bfloat = false;
-#endif
-
- strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
- }
-
- ctx->mtl_device_ref_count++;
-
- return ctx->mtl_device;
-}
-
-// release
-static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
- assert(ctx != NULL);
- assert(ctx->mtl_device_ref_count > 0);
-
- ctx->mtl_device_ref_count--;
-
- if (ctx->mtl_device_ref_count == 0) {
- [ctx->mtl_device release];
- ctx->mtl_device = nil;
- }
-}
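// The pair above implements manual reference counting of the shared MTLDevice. A
// balanced acquire/release from within this translation unit would look like the
// following sketch (illustrative, not verbatim from the original file):
//
//     id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); // ref count -> 1
//     // ... use the device ...
//     ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);                        // ref count -> 0, device released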
-
-// kernels
-
-struct ggml_metal_kernel {
- id<MTLComputePipelineState> pipeline;
-};
-
-enum ggml_metal_kernel_type {
- GGML_METAL_KERNEL_TYPE_ADD,
- GGML_METAL_KERNEL_TYPE_ADD_ROW,
- GGML_METAL_KERNEL_TYPE_SUB,
- GGML_METAL_KERNEL_TYPE_SUB_ROW,
- GGML_METAL_KERNEL_TYPE_MUL,
- GGML_METAL_KERNEL_TYPE_MUL_ROW,
- GGML_METAL_KERNEL_TYPE_DIV,
- GGML_METAL_KERNEL_TYPE_DIV_ROW,
- GGML_METAL_KERNEL_TYPE_REPEAT_F32,
- GGML_METAL_KERNEL_TYPE_REPEAT_F16,
- GGML_METAL_KERNEL_TYPE_REPEAT_I32,
- GGML_METAL_KERNEL_TYPE_REPEAT_I16,
- GGML_METAL_KERNEL_TYPE_SCALE,
- GGML_METAL_KERNEL_TYPE_SCALE_4,
- GGML_METAL_KERNEL_TYPE_CLAMP,
- GGML_METAL_KERNEL_TYPE_TANH,
- GGML_METAL_KERNEL_TYPE_RELU,
- GGML_METAL_KERNEL_TYPE_SIGMOID,
- GGML_METAL_KERNEL_TYPE_GELU,
- GGML_METAL_KERNEL_TYPE_GELU_4,
- GGML_METAL_KERNEL_TYPE_GELU_QUICK,
- GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
- GGML_METAL_KERNEL_TYPE_SILU,
- GGML_METAL_KERNEL_TYPE_SILU_4,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
- GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
- GGML_METAL_KERNEL_TYPE_GROUP_NORM,
- GGML_METAL_KERNEL_TYPE_NORM,
- GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
- GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
- GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
- GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
- GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
- GGML_METAL_KERNEL_TYPE_IM2COL_F16,
- GGML_METAL_KERNEL_TYPE_IM2COL_F32,
- GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
- GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
- GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
- GGML_METAL_KERNEL_TYPE_PAD_F32,
- GGML_METAL_KERNEL_TYPE_ARANGE_F32,
- GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
- GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
- GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
- GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
- GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
- GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
- GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
- GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
- GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
- GGML_METAL_KERNEL_TYPE_CONCAT,
- GGML_METAL_KERNEL_TYPE_SQR,
- GGML_METAL_KERNEL_TYPE_SQRT,
- GGML_METAL_KERNEL_TYPE_SIN,
- GGML_METAL_KERNEL_TYPE_COS,
- GGML_METAL_KERNEL_TYPE_SUM_ROWS,
- GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
- GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
-
- GGML_METAL_KERNEL_TYPE_COUNT
-};
-
-struct ggml_backend_metal_context {
- id<MTLCommandQueue> queue;
-
- dispatch_queue_t d_queue;
-
- struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
-
- // capture state
- bool capture_next_compute;
- bool capture_started;
-
- id<MTLCaptureScope> capture_scope;
-
- // command buffer state
- int n_cb; // number of extra threads used to submit the command buffers
- int n_nodes_0; // number of nodes submitted by the main thread
- int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
- int n_nodes_per_cb;
-
- struct ggml_cgraph * gf;
-
- // the callback given to the thread pool
- void (^encode_async)(size_t ith);
-
- // n_cb command buffers + 1 used by the main thread
- id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
-
- // abort ggml_metal_graph_compute if callback returns true
- ggml_abort_callback abort_callback;
- void * abort_callback_data;
-};
-
-// MSL code
-// TODO: move the contents here when ready
-// for now it is easier to work in a separate file
-// static NSString * const msl_library_source = @"see metal.metal";
-
-// Here to assist with NSBundle Path Hack
-@interface GGMLMetalClass : NSObject
-@end
-@implementation GGMLMetalClass
-@end
-
-static void * ggml_metal_host_malloc(size_t n) {
- void * data = NULL;
-
-#if TARGET_OS_OSX
- kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
- if (err != KERN_SUCCESS) {
- GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
- return NULL;
- }
-#else
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
- if (result != 0) {
- GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
- }
-#endif
-
- return data;
-}
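// Page alignment matters here: Metal's -newBufferWithBytesNoCopy:length:options:deallocator:
// expects page-aligned host memory, which is the usual way to expose such an allocation to
// the GPU without copying it. A sketch of that wrapping (illustrative; assumes `device` is a
// valid id<MTLDevice> and `size` is a multiple of the page size):
//
//     void * data = ggml_metal_host_malloc(size);
//     id<MTLBuffer> buf = [device newBufferWithBytesNoCopy:data
//                                                   length:size
//                                                  options:MTLResourceStorageModeShared
//                                              deallocator:nil];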
-
-static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
- GGML_LOG_INFO("%s: allocating\n", __func__);
-
-#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
- // Show all the Metal device instances in the system
- NSArray * devices = MTLCopyAllDevices();
- for (id<MTLDevice> device in devices) {
- GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
- }
- [devices release]; // since it was created by a *Copy* C function, the caller owns the returned array
-#endif
-
- // init context
- struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
- struct ggml_backend_metal_device_context * ctx_dev = dev->context;
-
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
- GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
-
- ctx->queue = [device newCommandQueue];
- ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
-
- id<MTLLibrary> metal_library;
-
- // load library
- //
- // - first check if the library is embedded
- // - then check if the library is in the bundle
- // - if not found, load the source and compile it
- // - if that fails, return NULL
- {
- NSBundle * bundle = nil;
-#ifdef SWIFT_PACKAGE
- bundle = SWIFTPM_MODULE_BUNDLE;
-#else
- bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-#endif
-
- NSError * error = nil;
-
-#if GGML_METAL_EMBED_LIBRARY
- const bool try_metallib = false;
-#else
- const bool try_metallib = true;
-#endif
-
- NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
- if (try_metallib && path_lib != nil) {
- // pre-compiled library found
- NSURL * libURL = [NSURL fileURLWithPath:path_lib];
- GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
-
- metal_library = [device newLibraryWithURL:libURL error:&error];
- if (error) {
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
- } else {
-#if GGML_METAL_EMBED_LIBRARY
- GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
-
- extern const char ggml_metallib_start[];
- extern const char ggml_metallib_end[];
-
- NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
-#else
- GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
-
- NSString * path_source;
- NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
-
- GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
-
- if (path_resource) {
- path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
- } else {
- path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- }
-
- if (path_source == nil) {
- GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
- path_source = @"ggml-metal.metal";
- }
-
- GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
-
- NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
- if (error) {
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
-#endif // GGML_METAL_EMBED_LIBRARY
-
- @autoreleasepool {
- // dictionary of preprocessor macros
- NSMutableDictionary * prep = [NSMutableDictionary dictionary];
-
- if (ctx_dev->use_bfloat) {
- [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
- }
-
- MTLCompileOptions * options = [MTLCompileOptions new];
- options.preprocessorMacros = prep;
-
- //[options setFastMathEnabled:false];
-
- metal_library = [device newLibraryWithSource:src options:options error:&error];
- if (error) {
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
-
-#if !__has_feature(objc_arc)
- [options release];
-#endif
- }
-#if GGML_METAL_EMBED_LIBRARY
- [src release];
-#endif // GGML_METAL_EMBED_LIBRARY
- }
- }
-
- // print MTL GPU family:
- GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- {
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([device supportsFamily:i]) {
- GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
- if ([device supportsFamily:i]) {
- GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
- break;
- }
- }
-
- for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
- if ([device supportsFamily:i]) {
- GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
- break;
- }
- }
- }
-
- GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
- GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
- GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
- GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
- GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
-
- ctx->capture_next_compute = false;
- ctx->capture_started = false;
- ctx->capture_scope = nil;
-
- ctx->gf = nil;
- ctx->encode_async = nil;
- for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
- ctx->command_buffers[i] = nil;
- }
-
-#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
- if (@available(macOS 10.12, iOS 16.0, *)) {
- GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
- }
-#endif
-
- // load kernels
- {
- NSError * error = nil;
-
- for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
- ctx->kernels[i].pipeline = nil;
- }
-
-#define GGML_METAL_ADD_KERNEL(e, name, supported) \
- if (supported) { \
- struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
- id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
- kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
- GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
- (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
- (int) kernel->pipeline.threadExecutionWidth); \
- [metal_function release]; \
- if (error) { \
- GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- [metal_library release]; \
- return NULL; \
- } \
- } else { \
- GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
- }
-
- const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
- const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
- const bool use_bfloat = ctx_dev->use_bfloat;
-
- // simd_sum and simd_max require MTLGPUFamilyApple7
-
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
- }
-
- [metal_library release];
-
- return ctx;
-}
-
-static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
- GGML_LOG_INFO("%s: deallocating\n", __func__);
-
- for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
- [ctx->kernels[i].pipeline release];
- }
-
- Block_release(ctx->encode_async);
-
- [ctx->queue release];
-
- dispatch_release(ctx->d_queue);
-
- free(ctx);
-}
-
-// temporarily defined here for compatibility between ggml-backend and the old API
-
-struct ggml_backend_metal_buffer {
- void * data;
- size_t size;
-
- id<MTLBuffer> metal;
-};
-
-struct ggml_backend_metal_buffer_context {
- void * all_data;
- size_t all_size;
- bool owned;
-
- // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
- int n_buffers;
- struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
-};
-
-// finds the Metal buffer that contains the tensor data on the GPU device
-// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
-// Metal buffer based on the host memory pointer
-//
-static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
- //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
-
- const int64_t tsize = ggml_nbytes(t);
-
- ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
-
- struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
-
- // find the view that contains the tensor fully
- for (int i = 0; i < buf_ctx->n_buffers; ++i) {
- const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
-
- //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
- if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
- *offs = (size_t) ioffs;
-
- //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
-
- return buf_ctx->buffers[i].metal;
- }
- }
-
- GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
-
- return nil;
-}
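// Worked example of the lookup above (illustrative numbers): with
// buf_ctx->buffers[0].data == 0x10000 and buf_ctx->buffers[0].size == 65536, a tensor
// whose t->data == 0x12000 and ggml_nbytes(t) == 4096 gives ioffs == 0x2000; since
// 0x2000 + 4096 <= 65536, the function returns buffers[0].metal with *offs == 0x2000.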
-
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
- const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
- const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
- const bool use_bfloat = ctx_dev->use_bfloat;
-
- if (!use_bfloat) {
- for (size_t i = 0, n = 3; i < n; ++i) {
- if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
- return false;
- }
- }
- }
-
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_TANH:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_SIGMOID:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_SILU:
- return ggml_is_contiguous(op->src[0]);
- default:
- return false;
- }
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- case GGML_OP_CONCAT:
- case GGML_OP_ADD:
- case GGML_OP_SUB:
- case GGML_OP_ACC:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_REPEAT:
- case GGML_OP_SCALE:
- case GGML_OP_CLAMP:
- return true;
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_SUM_ROWS:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_RMS_NORM:
- case GGML_OP_GROUP_NORM:
- return has_simdgroup_reduction;
- case GGML_OP_NORM:
- case GGML_OP_ROPE:
- return true;
- case GGML_OP_IM2COL:
- return op->src[0]->type == GGML_TYPE_F16;
- case GGML_OP_POOL_1D:
- return false;
- case GGML_OP_POOL_2D:
- case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
- case GGML_OP_ARANGE:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_ARGSORT:
- case GGML_OP_LEAKY_RELU:
- return true;
- case GGML_OP_FLASH_ATTN_EXT:
- if (op->src[1]->type != op->src[2]->type) {
- return false;
- }
- return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
- case GGML_OP_SSM_CONV:
- case GGML_OP_SSM_SCAN:
- return true;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- return has_simdgroup_reduction &&
- (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
- case GGML_OP_CPY:
- case GGML_OP_DUP:
- case GGML_OP_CONT:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_IQ4_NL:
- return true;
- default:
- return false;
- }
- case GGML_TYPE_F16:
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- return true;
- default:
- return false;
- }
- case GGML_TYPE_BF16:
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_BF16:
- return true;
- default:
- return false;
- }
- default:
- return false;
- }
- }
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_GET_ROWS:
- {
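- // only tensors with a trivial 4th dimension (ne[3] == 1) are handled by these kernels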
- return op->ne[3] == 1;
- }
- default:
- return false;
- }
-}
-
-static void ggml_metal_encode_node(
- ggml_backend_t backend,
- int idx,
- id<MTLComputeCommandEncoder> encoder) {
- struct ggml_backend_metal_context * ctx = backend->context;
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
- struct ggml_cgraph * gf = ctx->gf;
-
- struct ggml_tensor * node = ggml_graph_node(gf, idx);
-
- //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
-
- struct ggml_tensor * src0 = node->src[0];
- struct ggml_tensor * src1 = node->src[1];
- struct ggml_tensor * src2 = node->src[2];
- struct ggml_tensor * dst = node;
-
- if (ggml_is_empty(dst)) {
- return;
- }
-
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } return;
- default:
- {
- } break;
- }
-
- if (!ggml_metal_supports_op(ctx_dev, dst)) {
- GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ABORT("unsupported op");
- }
-
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0;
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0;
-
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
- const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_src2 = 0;
- size_t offs_dst = 0;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
- id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
-
-#if 0
- GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- if (src0) {
- GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
- ggml_is_contiguous(src0), src0->name);
- }
- if (src1) {
- GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
- ggml_is_contiguous(src1), src1->name);
- }
- if (dst) {
- GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
- dst->name);
- }
-#endif
-
- id<MTLDevice> device = ctx_dev->mtl_device;
-
- switch (dst->op) {
- case GGML_OP_CONCAT:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
-
- const int32_t dim = ((const int32_t *) dst->op_params)[0];
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ADD:
- case GGML_OP_SUB:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
-
- const size_t offs = 0;
-
- bool bcast_row = false;
-
- int64_t nb = ne00; // used by the "row" kernels
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
-
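- // note: the *_ROW kernels appear to operate on float4 vectors, so nb counts float4
- // elements per row (matching the ne00 % 4 == 0 and ne10 % 4 == 0 requirement above)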
- nb = ne00 / 4;
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
-
- bcast_row = true;
- } else {
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
-
- if (bcast_row) {
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } else {
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
- } break;
- case GGML_OP_REPEAT:
- {
- id<MTLComputePipelineState> pipeline;
-
- switch (src0t) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
- case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ACC:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- GGML_ASSERT(dstt == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
- const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
- const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
- const size_t offs = ((const int32_t *) dst->op_params)[3];
-
- const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
-
- if (!inplace) {
- // run a separate kernel to cpy src->dst
- // not sure how to avoid this
- // TODO: make a simpler cpy_bytes kernel
-
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
-
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_SCALE:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- float scale;
- memcpy(&scale, dst->op_params, sizeof(scale));
-
- int64_t n = ggml_nelements(dst);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (n % 4 == 0) {
- n /= 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_CLAMP:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
-
- float min;
- float max;
- memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
- memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_UNARY:
- // we are not taking into account the strides, so for now require contiguous tensors
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- switch (ggml_get_unary_op(node)) {
- case GGML_UNARY_OP_TANH:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_RELU:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SIGMOID:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU:
- {
- int64_t n = ggml_nelements(dst);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU_QUICK:
- {
- int64_t n = ggml_nelements(dst);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SILU:
- {
- int64_t n = ggml_nelements(dst);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- default:
- {
- GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_SQR:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SQRT:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SIN:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_COS:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SUM_ROWS:
- {
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SOFT_MAX:
- {
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
- int nth = 32; // SIMD width
-
- id<MTLComputePipelineState> pipeline = nil;
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
- if (ne00%4 == 0) {
- while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
- }
- } else {
- while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
- }
- }
-
- float scale;
- float max_bias;
-
- memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
-
- const int64_t nrows_x = ggml_nrows(src0);
- const int64_t nrows_y = src0->ne[1];
-
- const uint32_t n_head = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
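- // hedged note: m0, m1 and n_head_log2 follow the ALiBi slope scheme; a sketch (not part
- // of the kernel) of how a per-head slope could be reconstructed from them:
- //
- //   static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
- //       return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
- //   }
- //
- // with max_bias == 0.0f both m0 and m1 become 1.0f and the bias has no effect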
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
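- // assumption: the 32 floats of threadgroup memory hold one partial value per simdgroup
- // for the cross-simdgroup max/sum reduction inside the soft_max kernel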
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- const int n_past = ((const int32_t *)(dst->op_params))[0];
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (ne00%8 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
-
- if (ne00%8 == 0) {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- else {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- } break;
- case GGML_OP_SSM_CONV:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SSM_SCAN:
- {
- struct ggml_tensor * src3 = node->src[3];
- struct ggml_tensor * src4 = node->src[4];
- struct ggml_tensor * src5 = node->src[5];
-
- GGML_ASSERT(src3);
- GGML_ASSERT(src4);
- GGML_ASSERT(src5);
-
- size_t offs_src3 = 0;
- size_t offs_src4 = 0;
- size_t offs_src5 = 0;
-
- id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
- id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
- id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
-
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
- const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
-
- const uint64_t nb30 = src3->nb[0];
- const uint64_t nb31 = src3->nb[1];
-
- const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
- const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
-
- const uint64_t nb40 = src4->nb[0];
- const uint64_t nb41 = src4->nb[1];
- const uint64_t nb42 = src4->nb[2];
-
- const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
- const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
- const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
-
- const uint64_t nb50 = src5->nb[0];
- const uint64_t nb51 = src5->nb[1];
- const uint64_t nb52 = src5->nb[2];
-
- const int64_t d_state = ne00;
- const int64_t d_inner = ne01;
- const int64_t n_seq_tokens = ne11;
- const int64_t n_seqs = ne02;
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
- [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
- [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
-
- [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
- [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
- [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
- [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
-
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
- [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
- [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
- [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
- [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
- [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
- [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
- [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
- [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_MUL_MAT:
- {
- GGML_ASSERT(ne00 == ne10);
-
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- const uint r2 = ne12/ne02;
- const uint r3 = ne13/ne03;
-
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- int ne11_mm_min = 1;
-
-#if 0
- // the numbers below are measured on M2 Ultra for 7B and 13B models
- // these numbers do not translate to other devices or model sizes
- // TODO: need to find a better approach
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
- }
-#endif
-
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
- if ([device supportsFamily:MTLGPUFamilyApple7] &&
- !ggml_is_transposed(src0) &&
- !ggml_is_transposed(src1) &&
- src1t == GGML_TYPE_F32 &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- // some Metal matrix data types require aligned pointers
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
- case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
- default: break;
- }
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
- default: GGML_ABORT("MUL MAT-MAT not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
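- // hedged reading of the dispatch below: each threadgroup of 128 threads computes a
- // 64 x 32 output tile (64 rows of src0 by 32 rows of src1), and the 8192 bytes of
- // threadgroup memory presumably stage the src0/src1 tiles for that tile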
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
- nrows = 4;
- } break;
- case GGML_TYPE_F16:
- {
- nth0 = 32;
- nth1 = 1;
- if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
- nrows = ne11;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
- nrows = 4;
- }
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
- nrows = 4;
- }
- } break;
- case GGML_TYPE_BF16:
- {
- nth0 = 32;
- nth1 = 1;
- if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
- nrows = ne11;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
- nrows = 4;
- }
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
- nrows = 4;
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_M:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_NL:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
- } break;
- default:
- {
- GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
- GGML_ABORT("not implemented");
- }
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
-
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
- const int mem_size = 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- const int n_as = src0->ne[2];
-
- // src2 = ids
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
- GGML_ASSERT(src2t == GGML_TYPE_I32);
-
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(src1t == GGML_TYPE_F32);
-
- GGML_ASSERT(ne03 == 1);
- GGML_ASSERT(ne13 == 1);
-
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- // ne20 = n_used_experts
- // ne21 = n_rows
- const int dst_rows = ne20*ne21;
- const int dst_rows_min = n_as;
- const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
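- // hedged reading of the bound above: the threadgroup allocation further down is
- // GGML_PAD(8192 + dst_rows*4, 16) bytes, so after the 8192-byte mat-mat tile area and a
- // small 32-byte reserve, each additional row id costs 4 bytes (one ushort2)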
-
- // max size of the rowids array in the kernel shared buffer
- GGML_ASSERT(dst_rows <= dst_rows_max);
-
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
- // !!!
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
- // indirect matrix multiplication
- // !!!
- if ([device supportsFamily:MTLGPUFamilyApple7] &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- dst_rows > dst_rows_min) {
- // some Metal matrix data types require aligned pointers
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
- case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
- default: break;
- }
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
- default: GGML_ABORT("MUL_MAT_ID not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
- } break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
- } break;
- case GGML_TYPE_BF16:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_M:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_NL:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
- } break;
- default:
- {
- GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
- GGML_ABORT("not implemented");
- }
- }
-
- if (ggml_is_quantized(src0t)) {
- GGML_ASSERT(ne00 >= nth0*nth1);
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
-
- const int64_t _ne1 = 1;
- const int tgz = dst_rows;
-
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
- const int mem_size = 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_GET_ROWS:
- {
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
- default: GGML_ABORT("not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
- } break;
- case GGML_OP_RMS_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(src0));
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- int nth = 32; // SIMD width
-
- while (nth < ne00/4 && nth < 1024) {
- nth *= 2;
- }
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_GROUP_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- float eps;
- memcpy(&eps, dst->op_params + 1, sizeof(float));
-
- const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
-
- int nth = 32; // SIMD width
-
- //while (nth < ne00/4 && nth < 1024) {
- // nth *= 2;
- //}
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_NORM:
- {
- GGML_ASSERT(ggml_is_contiguous_1(src0));
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- const int nth = MIN(256, ne00);
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ROPE:
- {
- GGML_ASSERT(ne10 == ne02);
-
- const int nth = MIN(1024, ne00);
-
- const int n_past = ((const int32_t *) dst->op_params)[0];
- const int n_dims = ((const int32_t *) dst->op_params)[1];
- const int mode = ((const int32_t *) dst->op_params)[2];
-                    // skip op_params[3] (n_ctx), used in GLM RoPE, not implemented in Metal
- const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
-
- float freq_base;
- float freq_scale;
- float ext_factor;
- float attn_factor;
- float beta_fast;
- float beta_slow;
-
- memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
-
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (!is_neox) {
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
- } else {
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- if (id_src2 != nil) {
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
- [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_IM2COL:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
- const int32_t N = src1->ne[is_2D ? 3 : 2];
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
- const int32_t IH = is_2D ? src1->ne[1] : 1;
- const int32_t IW = src1->ne[0];
-
- const int32_t KH = is_2D ? src0->ne[1] : 1;
- const int32_t KW = src0->ne[0];
-
- const int32_t OH = is_2D ? dst->ne[2] : 1;
- const int32_t OW = dst->ne[1];
-
- const int32_t CHW = IC * KH * KW;
-
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
-
- const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
-
- switch (dst->type) {
- case GGML_TYPE_F32: {
- pipeline = (is_gt_mttpt ?
- ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
- :
- ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
- } break;
- case GGML_TYPE_F16: {
- pipeline = (is_gt_mttpt ?
- ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
- :
- ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
- } break;
- default: GGML_ABORT("fatal error");
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
- [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
- [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
- [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
- [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
- [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
- [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
-
- if (is_gt_mttpt) {
- [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
- [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
- [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
-
- const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
-
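-                        // quotient = ceil(N / n_threads): number of threadgroups needed to cover the batch dimension N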
- const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
- } else {
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
- }
- } break;
- case GGML_OP_UPSCALE:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
-
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_PAD:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ARANGE:
- {
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- float start;
- float step;
-
- memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
- memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const int dim = dst->op_params[0];
- const int max_period = dst->op_params[1];
-
- const int half = dim / 2;
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
-
- const int nth = MIN(1024, half);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ARGSORT:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
- const int nrows = ggml_nrows(src0);
-
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
- // bitonic sort requires the number of elements to be power of 2
- int64_t ne00_padded = 1;
- while (ne00_padded < ne00) {
- ne00_padded *= 2;
- }
-
- // Metal kernels require the buffer size to be multiple of 16 bytes
- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
- const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (order) {
- case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
- case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
- } break;
- case GGML_OP_LEAKY_RELU:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- float slope;
- memcpy(&slope, dst->op_params, sizeof(float));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_FLASH_ATTN_EXT:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ne11 % 32 == 0);
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == src2->type);
-
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
-
- struct ggml_tensor * src3 = node->src[3];
-
- size_t offs_src3 = 0;
-
- id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-
- GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
- GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
-
- const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
- //const int64_t ne31 = src3 ? src3->ne[1] : 0;
- const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
- const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
-
- const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
- const uint64_t nb31 = src3 ? src3->nb[1] : 0;
- const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
- const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
-
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
- float scale;
- float max_bias;
- float logit_softcap;
- memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
- memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
-
- if (logit_softcap != 0.0f) {
- scale /= logit_softcap;
- }
-
- const uint32_t n_head = src0->ne[2];
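-              // n_head_log2 is the largest power of two not exceeding n_head; m0 and m1 below are the ALiBi slope bases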
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- bool use_vec_kernel = false;
-
- // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
- // for now avoiding mainly to keep the number of templates/kernels a bit lower
- if (ne01 >= 4 || (ne00%128 != 0)) {
- switch (src1->type) {
- case GGML_TYPE_F16:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_BF16:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_Q4_1:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_Q5_0:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_Q5_1:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- case GGML_TYPE_Q8_0:
- {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- use_vec_kernel = true;
-
- switch (ne00) {
- case 128:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 256:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
- [encoder setBytes:&scale length:sizeof( float) atIndex:20];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
- [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
-
- if (!use_vec_kernel) {
- // half8x8 kernel
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
- GGML_ASSERT(nqptg <= 32);
- GGML_ASSERT(nqptg % 8 == 0);
- GGML_ASSERT(ncpsg % 32 == 0);
-
- // 2*(2*ncpsg + nqptg)*(nsg)
- // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
- //
- // 16*32*(nsg)
- // the shared memory needed for the simdgroups to load the KV cache
-                // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
- //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
-
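-                // find the largest power-of-two simdgroup count whose shared memory requirement fits within the device limit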
- int64_t nsgmax = 2;
-
- while (true) {
- const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
-
- // simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-
- const size_t smem = FATTN_SMEM(nsg);
-
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
- } else {
- // half4x4 kernel
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
- GGML_ASSERT(nqptg <= 32);
- GGML_ASSERT(nqptg % 1 == 0);
- GGML_ASSERT(ncpsg % 32 == 0);
-
- // ne00 + 2*ncpsg*(nsg)
- // for each query, we load it as f16 in shared memory (ne00)
- // and store the soft_max values and the mask
- //
- // ne00*(nsg)
- // each simdgroup has a full f16 head vector in shared mem to accumulate results
- //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
-
- int64_t nsgmax = 2;
-
- while (true) {
- const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
-
- // simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
-
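-                    // round down to the largest power of two that does not exceed nsgt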
- int64_t nsg = 1;
- while (nsg <= nsgt) {
- nsg *= 2;
- }
- nsg /= 2;
-
- const size_t smem = FATTN_SMEM(nsg);
-
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
- }
- } break;
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- {
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
-
- switch (dstt) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
- default: GGML_ABORT("not implemented");
- };
- } break;
- case GGML_TYPE_F16:
- {
- switch (dstt) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
- default: GGML_ABORT("not implemented");
- };
- } break;
- case GGML_TYPE_BF16:
- {
- switch (dstt) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
- };
- } break;
- default: GGML_ABORT("not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_POOL_2D:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
-
- const int32_t * opts = dst->op_params;
- enum ggml_op_pool op = opts[0];
-
- id<MTLComputePipelineState> pipeline = nil;
- switch (src0t) {
- case GGML_TYPE_F32: {
- switch(op) {
- case GGML_OP_POOL_AVG:
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
- case GGML_OP_POOL_MAX:
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
- }
- } break;
- default: GGML_ASSERT(false && "not implemented");
- }
-
- const int32_t k0 = opts[1];
- const int32_t k1 = opts[2];
- const int32_t s0 = opts[3];
- const int32_t s1 = opts[4];
- const int32_t p0 = opts[5];
- const int32_t p1 = opts[6];
-
- const int64_t IH = src0->ne[1];
- const int64_t IW = src0->ne[0];
-
- const int64_t N = dst->ne[3];
- const int64_t OC = dst->ne[2];
- const int64_t OH = dst->ne[1];
- const int64_t OW = dst->ne[0];
-
- const int64_t parallel_elements = N * OC * OH * OW;
- const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
- const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
- [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
- [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
- [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
- [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
- [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
-                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
- } break;
- default:
- {
- GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
- GGML_ABORT("fatal error");
- }
- }
-}
-
-static enum ggml_status ggml_metal_graph_compute(
- ggml_backend_t backend,
- struct ggml_cgraph * gf) {
- struct ggml_backend_metal_context * ctx = backend->context;
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
- // number of nodes encoded by the main thread (empirically determined)
- const int n_main = 128;
-
- // number of threads in addition to the main thread
- const int n_cb = ctx->n_cb;
-
- // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
- // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
- // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
-    // each thread creates its own command buffer and enqueues the ops in parallel
- //
-    // tests on M1 Pro and M2 Ultra using LLaMA models show that the optimal values for n_cb are 1 or 2
-
- @autoreleasepool {
- ctx->gf = gf;
-
- ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
- ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
-
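-        // split the remaining nodes evenly across the n_cb command buffers (ceiling division)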
- ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
-
- const bool should_capture = ctx->capture_next_compute;
- if (should_capture) {
- ctx->capture_next_compute = false;
-
- if (!ctx->capture_started) {
- // create capture scope
- ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
-
- MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
- descriptor.captureObject = ctx->capture_scope;
- descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
- descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
-
- NSError * error = nil;
- if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
- GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
- } else {
- [ctx->capture_scope beginScope];
- ctx->capture_started = true;
- }
- }
- }
-
- // the main thread commits the first few commands immediately
- // command_buffer[n_cb]
- {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[n_cb] = command_buffer;
-
- [command_buffer enqueue];
- ctx->encode_async(n_cb);
- }
-
- // prepare the rest of the command buffers asynchronously
- // command_buffer[0.. n_cb)
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[cb_idx] = command_buffer;
-
- // always enqueue the first two command buffers
- // enqueue all of the command buffers if we don't need to abort
- if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer enqueue];
- }
- }
-
- dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
-
- // wait for completion and check status of each command buffer
- // needed to detect if the device ran out-of-memory for example (#1881)
- {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
- [command_buffer waitUntilCompleted];
-
- MTLCommandBufferStatus status = [command_buffer status];
- if (status != MTLCommandBufferStatusCompleted) {
- GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
- if (status == MTLCommandBufferStatusError) {
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
- }
-
- return GGML_STATUS_FAILED;
- }
- }
-
- for (int i = 0; i < n_cb; ++i) {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
- [command_buffer waitUntilCompleted];
-
- MTLCommandBufferStatus status = [command_buffer status];
- if (status != MTLCommandBufferStatusCompleted) {
- GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
- if (status == MTLCommandBufferStatusError) {
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
- }
-
- return GGML_STATUS_FAILED;
- }
-
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
- if (!next_buffer) {
- continue;
- }
-
- const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
- if (next_queued) {
- continue;
- }
-
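-            // the next command buffer has not been enqueued yet - check the abort callback before committing it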
- if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-                GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);
- return GGML_STATUS_ABORTED;
- }
-
- [next_buffer commit];
- }
-
- if (!should_capture && ctx->capture_started) {
- [ctx->capture_scope endScope];
- [[MTLCaptureManager sharedCaptureManager] stopCapture];
- }
- }
-
- return GGML_STATUS_SUCCESS;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend interface
-
-static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
- for (int i = 0; i < ctx->n_buffers; i++) {
- [ctx->buffers[i].metal release];
- }
- ggml_backend_metal_device_rel(buffer->buft->device->context);
-
- if (ctx->owned) {
-#if TARGET_OS_OSX
- vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
-#else
- free(ctx->all_data);
-#endif
- }
-
- free(ctx);
-}
-
-static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
- return ctx->all_data;
-}
-
-static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- memcpy((char *)tensor->data + offset, data, size);
-
- UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- memcpy(data, (const char *)tensor->data + offset, size);
-
- UNUSED(buffer);
-}
-
-static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
- if (ggml_backend_buffer_is_host(src->buffer)) {
- memcpy(dst->data, src->data, ggml_nbytes(src));
- return true;
- }
- return false;
-
- UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
- memset(ctx->all_data, value, ctx->all_size);
-}
-
-static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
- /* .init_tensor = */ NULL,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_metal_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// default buffer type
-
-static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "Metal";
-
- UNUSED(buft);
-}
-
-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
-#ifndef GGML_METAL_NDEBUG
-#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
- if (@available(macOS 10.12, iOS 16.0, *)) {
- GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
- __func__,
- size_aligned / 1024.0 / 1024.0,
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- }
- } else {
- GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
- __func__,
- size_aligned / 1024.0 / 1024.0,
- device.currentAllocatedSize / 1024.0 / 1024.0);
- }
-#endif
-#endif
- UNUSED(device);
- UNUSED(size_aligned);
-}
-
-static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
-
- ctx->all_data = ggml_metal_host_malloc(size_aligned);
- ctx->all_size = size_aligned;
- ctx->owned = true;
- ctx->n_buffers = 1;
-
- if (ctx->all_data != NULL) {
- ctx->buffers[0].data = ctx->all_data;
- ctx->buffers[0].size = size;
- ctx->buffers[0].metal = nil;
-
- if (size_aligned > 0) {
- ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
- length:size_aligned
- options:MTLResourceStorageModeShared
- deallocator:nil];
- }
- }
-
- if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
- free(ctx);
- ggml_backend_metal_device_rel(buft->device->context);
- return NULL;
- }
-
- //ggml_backend_metal_log_allocated_size(device, size_aligned);
-
- return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
-}
-
-static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return 32;
- UNUSED(buft);
-}
-
-static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
- const size_t max_size = device.maxBufferLength;
- ggml_backend_metal_device_rel(buft->device->context);
-
- return max_size;
-
- UNUSED(buft);
-}
-
-static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return true;
-
- UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
- static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
- },
- /* .device = */ &g_ggml_backend_metal_device,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_buffer_type_metal;
-}
-
-static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
- return "Metal_Mapped";
-
- UNUSED(buft);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
- static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
- },
- /* .device = */ &g_ggml_backend_metal_device,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_buffer_from_ptr_type_metal;
-}
-
-// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
-ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
- struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
- ctx->all_data = data;
- ctx->all_size = size;
- ctx->owned = false;
- ctx->n_buffers = 0;
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- // page-align the data ptr
- {
- const uintptr_t offs = (uintptr_t) data % size_page;
- data = (void *) ((char *) data - offs);
- size += offs;
- }
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
-
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].data = data;
- ctx->buffers[ctx->n_buffers].size = size;
- ctx->buffers[ctx->n_buffers].metal = nil;
-
- if (size_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
-                return NULL;
- }
- }
-
- ggml_backend_metal_log_allocated_size(device, size_aligned);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = device.maxBufferLength - size_ovlp;
- const size_t size_view = device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
- ctx->buffers[ctx->n_buffers].metal = nil;
-
- if (size_step_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
- return NULL;
- }
- }
-
- ggml_backend_metal_log_allocated_size(device, size_step_aligned);
-
- if (i + size_step < size) {
- GGML_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
-}
-
-// backend
-
-static const char * ggml_backend_metal_name(ggml_backend_t backend) {
- return "Metal";
-
- UNUSED(backend);
-}
-
-static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_backend_metal_context * ctx = backend->context;
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
- ggml_backend_metal_device_rel(ctx_dev);
- ggml_metal_free(ctx);
-
- free(backend);
-}
-
-static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- return ggml_metal_graph_compute(backend, cgraph);
-}
-
-static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
- GGML_ASSERT(ggml_backend_is_metal(backend));
-
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-
- if (ctx->n_cb != n_cb) {
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
-
- if (ctx->n_cb > 2) {
- GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
- }
- }
-
- if (ctx->encode_async) {
- Block_release(ctx->encode_async);
- }
-
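- // encoding block: command buffer `iter` encodes its slice of the graph nodes and commits it when allowed (see the condition at the end of the block)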
- ctx->encode_async = Block_copy(^(size_t iter) {
- const int cb_idx = iter;
- const int n_cb_l = ctx->n_cb;
-
- const int n_nodes_0 = ctx->n_nodes_0;
- const int n_nodes_1 = ctx->n_nodes_1;
-
- const int n_nodes_per_cb = ctx->n_nodes_per_cb;
-
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
-
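- // by default (cb_idx >= n_cb) this command buffer encodes the first n_nodes_0 nodes; otherwise it encodes the cb_idx-th slice of n_nodes_per_cb nodes out of the remaining n_nodes_1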
- int node_start = 0;
- int node_end = n_nodes_0;
-
- if (cb_idx < n_cb_l) {
- node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
- node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
- }
-
- const bool should_capture = ctx->capture_next_compute;
-
- for (int idx = node_start; idx < node_end; ++idx) {
- if (should_capture) {
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
- }
-
- ggml_metal_encode_node(backend, idx, encoder);
-
- if (should_capture) {
- [encoder popDebugGroup];
- }
- }
-
- [encoder endEncoding];
-
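- // commit right away unless an abort callback is set; with an abort callback only the first two command buffers are committed here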
- if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer commit];
- }
- });
-}
-
-static struct ggml_backend_i ggml_backend_metal_i = {
- /* .get_name = */ ggml_backend_metal_name,
- /* .free = */ ggml_backend_metal_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_metal_guid(void) {
- static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
- return &guid;
-}
-
-// TODO: remove in the future
-ggml_backend_t ggml_backend_metal_init(void) {
- ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
-
- struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
- if (ctx == NULL) {
- GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
- return NULL;
- }
-
- ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
-
- *backend = (struct ggml_backend) {
- /* .guid = */ ggml_backend_metal_guid(),
- /* .interface = */ ggml_backend_metal_i,
- /* .device = */ dev,
- /* .context = */ ctx,
- };
-
- ggml_backend_metal_set_n_cb(backend, 1);
-
- return backend;
-}
-
-bool ggml_backend_is_metal(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
-}
-
-void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
- GGML_ASSERT(ggml_backend_is_metal(backend));
-
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-
- ctx->abort_callback = abort_callback;
- ctx->abort_callback_data = user_data;
-}
-
-bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
- GGML_ASSERT(ggml_backend_is_metal(backend));
-
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
- return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
-}
-
-void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
- GGML_ASSERT(ggml_backend_is_metal(backend));
-
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
- ctx->capture_next_compute = true;
-}
-
-// backend device
-
-static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
- return "Metal";
-
- GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
- // acq/rel just to populate ctx->name in case it hasn't been done yet
- struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
- ggml_backend_metal_device_acq(ctx_dev);
- ggml_backend_metal_device_rel(ctx_dev);
-
- return ctx_dev->name;
-}
-
-static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- if (@available(macOS 10.12, iOS 16.0, *)) {
- struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
-
- *total = device.recommendedMaxWorkingSetSize;
- *free = *total - device.currentAllocatedSize;
-
- ggml_backend_metal_device_rel(ctx_dev);
- } else {
- *free = 1;
- *total = 1;
- }
-}
-
-static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-
- GGML_UNUSED(dev);
-}
-
-static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_metal_device_get_name(dev);
- props->description = ggml_backend_metal_device_get_description(dev);
- props->type = ggml_backend_metal_device_get_type(dev);
- ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
- props->caps = (struct ggml_backend_dev_caps) {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ true,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
- struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
- if (ctx == NULL) {
- GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
- return NULL;
- }
-
- ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
-
- *backend = (struct ggml_backend) {
- /* .guid = */ ggml_backend_metal_guid(),
- /* .interface = */ ggml_backend_metal_i,
- /* .device = */ dev,
- /* .context = */ ctx,
- };
-
- ggml_backend_metal_set_n_cb(backend, 1);
-
- return backend;
-
- GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
- return ggml_backend_metal_buffer_type();
-
- GGML_UNUSED(dev);
-}
-
-static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
- struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
- ctx->all_data = ptr;
- ctx->all_size = size;
- ctx->owned = false;
- ctx->n_buffers = 0;
-
- const size_t size_page = sysconf(_SC_PAGESIZE);
-
- // page-align the data ptr
- {
- const uintptr_t offs = (uintptr_t) ptr % size_page;
- ptr = (void *) ((char *) ptr - offs);
- size += offs;
- }
-
- size_t size_aligned = size;
- if ((size_aligned % size_page) != 0) {
- size_aligned += (size_page - (size_aligned % size_page));
- }
-
- struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
-
- // the buffer fits into the max buffer size allowed by the device
- if (size_aligned <= device.maxBufferLength) {
- ctx->buffers[ctx->n_buffers].data = ptr;
- ctx->buffers[ctx->n_buffers].size = size;
- ctx->buffers[ctx->n_buffers].metal = nil;
-
- if (size_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
- return NULL;
- }
- }
-
- ggml_backend_metal_log_allocated_size(device, size_aligned);
-
- ++ctx->n_buffers;
- } else {
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
- // one of the views
- const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
- const size_t size_step = device.maxBufferLength - size_ovlp;
- const size_t size_view = device.maxBufferLength;
-
- for (size_t i = 0; i < size; i += size_step) {
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i);
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
- ctx->buffers[ctx->n_buffers].metal = nil;
-
- if (size_step_aligned > 0) {
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
- GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
- return NULL;
- }
- }
-
- ggml_backend_metal_log_allocated_size(device, size_step_aligned);
-
- if (i + size_step < size) {
- GGML_LOG_INFO("\n");
- }
-
- ++ctx->n_buffers;
- }
- }
-
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
-}
-
-static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- struct ggml_backend_metal_device_context * ctx_dev = dev->context;
-
- return ggml_metal_supports_op(ctx_dev, op);
-}
-
-static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
-
- UNUSED(dev);
-}
-
-static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- return false;
-
- GGML_UNUSED(dev);
- GGML_UNUSED(op);
-}
-
-static struct ggml_backend_device_i ggml_backend_metal_device_i = {
- /* .get_name = */ ggml_backend_metal_device_get_name,
- /* .get_description = */ ggml_backend_metal_device_get_description,
- /* .get_memory = */ ggml_backend_metal_device_get_memory,
- /* .get_type = */ ggml_backend_metal_device_get_type,
- /* .get_props = */ ggml_backend_metal_device_get_props,
- /* .init_backend = */ ggml_backend_metal_device_init,
- /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
- /* .supports_op = */ ggml_backend_metal_device_supports_op,
- /* .supports_buft = */ ggml_backend_metal_device_supports_buft,
- /* .offload_op = */ ggml_backend_metal_device_offload_op,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-// backend registry
-
-static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
- return "Metal";
-
- GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
- return 1;
-
- GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
- GGML_ASSERT(index == 0);
-
- return &g_ggml_backend_metal_device;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(index);
-}
-
-static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
- /* .get_name = */ ggml_backend_metal_reg_get_name,
- /* .device_count = */ ggml_backend_metal_reg_device_count,
- /* .device_get = */ ggml_backend_metal_reg_device_get,
- /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_metal_reg(void) {
- // TODO: make this thread-safe somehow?
- {
- g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
- /* .iface = */ ggml_backend_metal_reg_i,
- /* .context = */ NULL,
- };
-
- g_ggml_backend_metal_device = (struct ggml_backend_device) {
- /* .iface = */ ggml_backend_metal_device_i,
- /* .reg = */ &g_ggml_backend_metal_reg,
- /* .context = */ &g_ggml_ctx_dev_main,
- };
- }
-
- return &g_ggml_backend_metal_reg;
-}
+++ /dev/null
-#define GGML_COMMON_DECL_METAL
-#define GGML_COMMON_IMPL_METAL
-#include "ggml-common.h"
-
-#include <metal_stdlib>
-
-using namespace metal;
-
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
-#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
-
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-
-// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-//
-// cmd:
-// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal.metal
-// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
-//
-#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
-#undef GGML_METAL_USE_BF16
-#endif
-
-#if defined(GGML_METAL_USE_BF16)
-typedef matrix<bfloat, 4, 4> bfloat4x4;
-#endif
-
-constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
-// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
-}
-
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
-}
-
-#if defined(GGML_METAL_USE_BF16)
-template <typename type4x4>
-void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
-}
-#endif
-
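- // q4_0: 4-bit quants with a single scale d per block; the dequantized value is d*q - 8*d (the -8*d offset is folded into md below)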
-template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float md = -8.h * xb->d;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
-
- float4x4 reg_f;
-
- for (int i = 0; i < 8; i++) {
- reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
- reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
- }
-
- reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float m = xb->m;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
-
- float4x4 reg_f;
-
- for (int i = 0; i < 8; i++) {
- reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
- reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
- }
-
- reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
- const float d = xb->d;
- const float md = -16.h * xb->d;
- const ushort mask = il ? 0x00F0 : 0x000F;
-
- const uint32_t qh = *((device const uint32_t *)xb->qh);
-
- const int x_mv = il ? 4 : 0;
-
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
-
- float4x4 reg_f;
-
- for (int i = 0; i < 8; i++) {
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
- // combine the 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
- reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
- reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
- }
-
- reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
- const float d = xb->d;
- const float m = xb->m;
- const ushort mask = il ? 0x00F0 : 0x000F;
-
- const uint32_t qh = *((device const uint32_t *)xb->qh);
-
- const int x_mv = il ? 4 : 0;
-
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
-
- float4x4 reg_f;
-
- for (int i = 0; i < 8; i++) {
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
- // combine the 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
- reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
- reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
- }
-
- reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
- device const int8_t * qs = ((device const int8_t *)xb->qs);
- const half d = xb->d;
-
- float4x4 reg_f;
-
- for (int i = 0; i < 16; i++) {
- reg_f[i/4][i%4] = (qs[i + 16*il] * d);
- }
-
- reg = (type4x4) reg_f;
-}
-
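- // q2_K: 2-bit quants; each 16-element group has a 4-bit scale and 4-bit min packed in scales[], applied on top of the super-block d and dmin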
-template <typename type4x4>
-void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const float d = xb->d;
- const float min = xb->dmin;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- float dl, ml;
- uint8_t sc = xb->scales[il];
-
- q = q + 32*(il/8) + 16*(il&1);
- il = (il/2)%4;
-
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl = d * (sc & 0xF) * coef;
- ml = min * (sc >> 4);
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
- device const int8_t * scales = (device const int8_t *)xb->scales;
-
- q = q + 32 * (il/8) + 16 * (il&1);
- h = h + 16 * (il&1);
- uint8_t m = 1 << (il/2);
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : ((il/4)>0 ? 12 : 3);
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
- const float ml = 4.f * dl;
-
- il = (il/2) & 3;
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl *= coef;
-
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
- }
-}
-
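- // unpack the 6-bit scale and 6-bit min for K-quant sub-block j from the packed scales array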
-static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
-}
-
-template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
- device const uchar * q = xb->qs;
-
- short is = (il/4) * 2;
- q = q + (il/4) * 32 + 16 * (il&1);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.h;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
-
- const ushort mask = il<2 ? 0x0F : 0xF0;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
- device const uint8_t * qh = xb->qh;
-
- short is = (il/4) * 2;
- q = q + 32 * (il/4) + 16 * (il&1);
- qh = qh + 16 * (il&1);
- uint8_t ul = 1 << (il/2);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.f;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
-
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
- device const int8_t * scales = (device const int8_t *)xb->scales;
-
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
- qh = qh + 32*(il/8) + 16*(il&1);
- float sc = scales[(il%2) + 2 * ((il/2))];
- il = (il/2) & 3;
-
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const float coef = il>1 ? 1.f/16.f : 1.f;
- const float ml = d_all * sc * 32.f;
- const float dl = d_all * sc * coef;
- for (int i = 0; i < 16; ++i) {
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
- reg[i/4][i%4] = dl * q - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
- device const uint16_t * q2 = xb->qs + 4*ib32;
- const uint32_t aux32_g = q2[0] | (q2[1] << 16);
- const uint32_t aux32_s = q2[2] | (q2[3] << 16);
- thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
- const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
- constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
- uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
- for (int i = 0; i < 8; ++i) {
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
- grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
- signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
- for (int i = 0; i < 8; ++i) {
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint16_t * q2 = xb->qs + 4*ib32;
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
- constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
- uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
- for (int i = 0; i < 8; ++i) {
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
- grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
- signs = ksigns_iq2xs[q2[2*il+1] >> 9];
- for (int i = 0; i < 8; ++i) {
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * q3 = xb->qs + 8*ib32;
- device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
- const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
- uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
- reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
- }
- grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
- grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
- signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
- for (int i = 0; i < 4; ++i) {
- reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
- reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * qs = xb->qs + 8*ib32;
- device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
- const uint8_t qh = xb->qh[ib32] >> 4*il;
- const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
- constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
- reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
- }
- grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
- grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
- for (int i = 0; i < 4; ++i) {
- reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
- reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint8_t * signs = qs + QK_K/8;
- const uint8_t qh = xb->qh[ib32] >> 4*il;
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
- constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
- for (int i = 0; i < 8; ++i) {
- reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
- reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- const float d = xb->d;
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint16_t * qh = xb->qh;
- const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
- const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
- const uint16_t h = qh[ib32] >> 6*il;
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * (grid1[i] & 0xf) + ml;
- reg[1][i] = dl * (grid1[i] >> 4) + ml;
- reg[2][i] = dl * (grid2[i] & 0xf) + ml;
- reg[3][i] = dl * (grid2[i] >> 4) + ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- device const uint16_t * sc = (device const uint16_t *)xb->scales;
-
- iq1m_scale_t scale;
- scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
- const float d = scale.f16;
-
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint8_t * qh = xb->qh + 2*ib32 + il;
-
- const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
- const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
- reg[1][i] = dl * (grid1[i] >> 4) + ml1;
- reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
- reg[3][i] = dl * (grid2[i] >> 4) + ml2;
- }
-}
-
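- // iq4_nl: 4-bit indices into the non-linear kvalues_iq4nl_f lookup table, scaled by the block scale d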
-template <typename type4x4>
-void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
- device const uint16_t * q4 = (device const uint16_t *)xb->qs;
- const float d = xb->d;
- uint32_t aux32;
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
- for (int i = 0; i < 4; ++i) {
- aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
- }
-}
-
-template <typename type4x4>
-void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
- const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
- const float d = (float)xb->d * (ls - 32);
- uint32_t aux32;
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
- for (int i = 0; i < 4; ++i) {
- aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
- }
-}
-
-enum ggml_sort_order {
- GGML_SORT_ORDER_ASC,
- GGML_SORT_ORDER_DESC,
-};
-
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-kernel void kernel_add(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
- }
-}
-
-kernel void kernel_sub(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
- }
-}
-
-kernel void kernel_mul(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
- }
-}
-
-kernel void kernel_div(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
- }
-}
-
-template<typename T>
-kernel void kernel_repeat(
- device const char * src0,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
-
- const int64_t i03 = i3 % ne03;
- const int64_t i02 = i2 % ne02;
- const int64_t i01 = i1 % ne01;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i00 = i0 % ne00;
- *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
- }
-}
-
-typedef decltype(kernel_repeat<float>) kernel_repeat_t;
-
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_add_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig % nb];
-}
-
-kernel void kernel_sub_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] - src1[tpig % nb];
-}
-
-kernel void kernel_mul_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % nb];
-}
-
-kernel void kernel_div_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] / src1[tpig % nb];
-}
-
-kernel void kernel_scale(
- device const float * src0,
- device float * dst,
- constant float & scale,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * scale;
-}
-
-kernel void kernel_scale_4(
- device const float4 * src0,
- device float4 * dst,
- constant float & scale,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * scale;
-}
-
-kernel void kernel_clamp(
- device const float * src0,
- device float * dst,
- constant float & min,
- constant float & max,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
-}
-
-kernel void kernel_relu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = precise::tanh(x);
-}
-
-constant float GELU_COEF_A = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
-
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- // BEWARE !!!
- // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
- // This was observed with Falcon 7B and 40B models
- //
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
-
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_silu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_sqr(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_sum_rows(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
-
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
- return;
- }
-
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
-
- float row_sum = 0;
-
- for (int64_t i0 = 0; i0 < ne00; i0++) {
- row_sum += src_row[i0];
- }
-
- dst_row[0] = row_sum;
-}
-
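- // softmax along each row: compute the row max, exponentiate and accumulate the sum, then normalize; an optional mask (src1) is added, scaled by an ALiBi slope when max_bias > 0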
-template<typename T>
-kernel void kernel_soft_max(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (ne02*ne01);
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
-
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-
- float slope = 1.0f;
-
- // ALiBi
- if (max_bias > 0.0f) {
- const int64_t h = i02;
-
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- // parallel max
- float lmax = -INFINITY;
-
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
- }
-
- // find the max value in the block
- float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = -INFINITY;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = max_val;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- max_val = buf[tiisg];
- max_val = simd_max(max_val);
- }
-
- // parallel sum
- float lsum = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
- lsum += exp_psrc0;
- pdst[i00] = exp_psrc0;
- }
-
- // This barrier fixes a failing test
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
- threadgroup_barrier(mem_flags::mem_none);
-
- float sum = simd_sum(lsum);
-
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- sum = buf[tiisg];
- sum = simd_sum(sum);
- }
-
- const float inv_sum = 1.0f/sum;
-
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- pdst[i00] *= inv_sum;
- }
-}
-
-template<typename T>
-kernel void kernel_soft_max_4(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (ne02*ne01);
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
-
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-
- float slope = 1.0f;
-
- if (max_bias > 0.0f) {
- const int64_t h = i02;
-
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- // parallel max
- float4 lmax4 = -INFINITY;
-
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
- }
-
- const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-
- float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = -INFINITY;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = max_val;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- max_val = buf[tiisg];
- max_val = simd_max(max_val);
- }
-
- // parallel sum
- float4 lsum4 = 0.0f;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
- lsum4 += exp_psrc4;
- pdst4[i00] = exp_psrc4;
- }
-
- const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
-
- // This barrier fixes a failing test
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
- threadgroup_barrier(mem_flags::mem_none);
-
- float sum = simd_sum(lsum);
-
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- sum = buf[tiisg];
- sum = simd_sum(sum);
- }
-
- const float inv_sum = 1.0f/sum;
-
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- pdst4[i00] *= inv_sum;
- }
-}
-
-typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
-typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
-
-template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
-template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
-template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
-template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
-
-kernel void kernel_diag_mask_inf(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int & n_past,
- uint3 tpig[[thread_position_in_grid]]) {
- const int64_t i02 = tpig[2];
- const int64_t i01 = tpig[1];
- const int64_t i00 = tpig[0];
-
- if (i00 > n_past + i01) {
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
- } else {
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
-}
-
-kernel void kernel_diag_mask_inf_8(
- device const float4 * src0,
- device float4 * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int & n_past,
- uint3 tpig[[thread_position_in_grid]]) {
-
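- // each thread copies two float4 values (8 consecutive floats) and then sets to -INFINITY the elements that lie above the diagonal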
- const int64_t i = 2*tpig[0];
-
- dst[i+0] = src0[i+0];
- dst[i+1] = src0[i+1];
- int64_t i4 = 4*i;
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
- const int64_t i00 = i4;
- for (int k = 3; k >= 0; --k) {
- if (i00 + 4 + k <= n_past + i01) {
- break;
- }
- dst[i+1][k] = -INFINITY;
- if (i00 + k > n_past + i01) {
- dst[i][k] = -INFINITY;
- }
- }
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
-kernel void kernel_ssm_conv_f32(
- device const void * src0,
- device const void * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t ir = tgpig.x;
- const int64_t i2 = tgpig.y;
- const int64_t i3 = tgpig.z;
-
- const int64_t nc = ne10;
- //const int64_t ncs = ne00;
- //const int64_t nr = ne01;
- //const int64_t n_t = ne1;
- //const int64_t n_s = ne2;
-
- device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
- device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
- device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
-
- float sumf = 0.0f;
-
- for (int64_t i0 = 0; i0 < nc; ++i0) {
- sumf += s[i0] * c[i0];
- }
-
- x[0] = sumf;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
-kernel void kernel_ssm_scan_f32(
- device const void * src0,
- device const void * src1,
- device const void * src2,
- device const void * src3,
- device const void * src4,
- device const void * src5,
- device float * dst,
- constant int64_t & d_state,
- constant int64_t & d_inner,
- constant int64_t & n_seq_tokens,
- constant int64_t & n_seqs,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb20,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb30,
- constant uint64_t & nb31,
- constant uint64_t & nb40,
- constant uint64_t & nb41,
- constant uint64_t & nb42,
- constant uint64_t & nb50,
- constant uint64_t & nb51,
- constant uint64_t & nb52,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t ir = tgpig.x;
- const int64_t i3 = tgpig.y;
-
- const int64_t nc = d_state;
- //const int64_t nr = d_inner;
- const int64_t n_t = n_seq_tokens;
- //const int64_t n_s = n_seqs;
-
- for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
- device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
- device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
- device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
- device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
-
- if (i2 > 0) {
- s0 = s;
- }
-
- // i1 == 0
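-        // softplus: log(1 + exp(dt)); for dt > 20 the exp would overflow and softplus(dt) ~= dt, so dt is used directly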
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
- float x_dt = x[0] * dt_soft_plus;
- float sumf = 0.0f;
-
- for (int64_t i0 = 0; i0 < nc; ++i0) {
- int64_t i = i0;
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- sumf += state * C[i0];
- s[i] = state;
- }
-
- y[0] = sumf;
- }
-}
-
-kernel void kernel_norm(
- device const void * src0,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant float & eps,
- threadgroup float * sum [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
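-    // two-pass normalization: reduce the row mean, recenter and reduce the variance, then scale by 1/sqrt(variance + eps)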
- // MEAN
- // parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00];
- }
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const float mean = sum[0] / ne00;
-
- // recenter and VARIANCE
- threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * y = dst + tgpig*ne00;
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- y[i00] = x[i00] - mean;
- sum[tpitg] += y[i00] * y[i00];
- }
-
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const float variance = sum[0] / ne00;
-
- const float scale = 1.0f/sqrt(variance + eps);
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- y[i00] = y[i00] * scale;
- }
-}
-
-kernel void kernel_rms_norm(
- device const void * src0,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant float & eps,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
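-    // RMS norm: y = x / sqrt(mean(x^2) + eps); the row is processed as float4, so ne00 is expected to be a multiple of 4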
-
- float4 sumf = 0;
- float all_sum = 0;
-
- // parallel sum
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- sumf += x[i00] * x[i00];
- }
- all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
- all_sum = simd_sum(all_sum);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = all_sum;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- all_sum = buf[tiisg];
- all_sum = simd_sum(all_sum);
- }
-
- const float mean = all_sum/ne00;
- const float scale = 1.0f/sqrt(mean + eps);
-
- device float4 * y = (device float4 *) (dst + tgpig*ne00);
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- y[i00] = x[i00] * scale;
- }
-}
-
-kernel void kernel_group_norm(
- device const float * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int32_t & n_groups,
- constant float & eps,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t ne = ne00*ne01*ne02;
- const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
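-    // gs = elements per group: the ne02 channels are split into n_groups groups and each threadgroup normalizes one group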
-
- int start = tgpig * gs;
- int end = start + gs;
-
- start += tpitg;
-
- if (end >= ne) {
- end = ne;
- }
-
- float tmp = 0.0f; // partial sum for thread in warp
-
- for (int j = start; j < end; j += ntg) {
- tmp += src0[j];
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
- tmp = simd_sum(tmp);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = tmp;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- tmp = buf[tiisg];
- tmp = simd_sum(tmp);
- }
-
- const float mean = tmp / gs;
- tmp = 0.0f;
-
- for (int j = start; j < end; j += ntg) {
- float xi = src0[j] - mean;
- dst[j] = xi;
- tmp += xi * xi;
- }
-
- tmp = simd_sum(tmp);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (tiisg == 0) {
- buf[sgitg] = tmp;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- tmp = buf[tiisg];
- tmp = simd_sum(tmp);
- }
-
- const float variance = tmp / gs;
- const float scale = 1.0f/sqrt(variance + eps);
- for (int j = start; j < end; j += ntg) {
- dst[j] *= scale;
- }
-}
-
-// function for calculating the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
-// il indicates where the q4 quants begin (0 or QK4_0/4)
-// we assume that the yl's have been multiplied by the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
-
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
- device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
-
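-    // each uint16_t packs four 4-bit quants; they are masked at their native bit positions instead of shifted,
-    // which is compensated by the 1, 1/16, 1/256, 1/4096 pre-scaling of yl described above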
- for (int i = 0; i < 8; i += 2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
- acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
- }
-
- return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
-}
-
-// function for calculating the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
-// il indicates where the q4 quants begin (0 or QK4_0/4)
-// we assume that the yl's have been multiplied by the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float m = qb_curr->m;
-
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
- device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
-
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
- acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
- }
-
- return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
-}
-
-// function for calculating the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
-// il indicates where the q5 quants begin (0 or QK5_0/4)
-// we assume that the yl's have been multiplied by the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
-
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
-
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
- acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
- acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
- }
-
- return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
-}
-
-// function for calculating the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
-// il indicates where the q5 quants begin (0 or QK5_1/4)
-// we assume that the yl's have been multiplied by the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float m = qb_curr->m;
-
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
-
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
- acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
- acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
- }
-
- return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
-}
-
-// putting these constants in the kernel causes a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-// Note: This is a template, but strictly speaking it only applies to
-// quantizations where the block size is 32. It also does not
-// guard against the number of rows not being divisible by
-// N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
- const int nb = ne00/QK4_0;
-
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * nsg + sgitg) * nr;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- // pointers to src0 rows
- device const block_q_type * ax[nr];
- for (int row = 0; row < nr; ++row) {
- const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
- ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
- }
-
- float yl[16]; // src1 vector cache
- float sumf[nr] = {0.f};
-
- const int ix = (tiisg/2);
- const int il = (tiisg%2)*8;
-
- device const float * yb = y + ix * QK4_0 + il;
-
- // each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += nw/2) {
- float sumy[2] = { 0.f, 0.f };
-
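-        // load 16 src1 values, pre-scaled by 1, 1/256, 1/16 and 1/4096 to match the masked (unshifted) quants in block_q_n_dot_y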
-#pragma unroll
- for (int i = 0; i < 8; i += 2) {
- sumy[0] += yb[i + 0] + yb[i + 1];
- yl[i + 0] = yb[i + 0];
- yl[i + 1] = yb[i + 1]/256.f;
-
- sumy[1] += yb[i + 16] + yb[i + 17];
- yl[i + 8] = yb[i + 16]/16.f;
- yl[i + 9] = yb[i + 17]/4096.f;
- }
-
-#pragma unroll
- for (int row = 0; row < nr; row++) {
- sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
- }
-
- yb += QK4_0 * 16;
- }
-
- for (int row = 0; row < nr; ++row) {
- const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < ne01) {
- dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
- }
- }
-}
-
-kernel void kernel_mul_mv_q4_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q4_1_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q5_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q5_1_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-
-#define NB_Q8_0 8
-
-void kernel_mul_mv_q8_0_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
- const int nr = N_DST;
- const int nsg = N_SIMDGROUP;
- const int nw = N_SIMDWIDTH;
-
- const int nb = ne00/QK8_0;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * nsg + sgitg) * nr;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- // pointers to src0 rows
- device const block_q8_0 * ax[nr];
- for (int row = 0; row < nr; ++row) {
- const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
- ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
- }
-
- float yl[NB_Q8_0];
- float sumf[nr]={0.f};
-
- const int ix = tiisg/4;
- const int il = tiisg%4;
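-    // 4 threads cooperate on one block of 32 quants: ix selects the starting block, il selects the group of NB_Q8_0 quants within it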
-
- device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
-
- // each thread in a SIMD group deals with NB_Q8_0 quants at a time
- for (int ib = ix; ib < nb; ib += nw/4) {
- for (int i = 0; i < NB_Q8_0; ++i) {
- yl[i] = yb[i];
- }
-
- for (int row = 0; row < nr; row++) {
- device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
- float sumq = 0.f;
- for (int iq = 0; iq < NB_Q8_0; ++iq) {
- sumq += qs[iq] * yl[iq];
- }
- sumf[row] += sumq*ax[row][ib].d;
- }
-
- yb += NB_Q8_0 * nw;
- }
-
- for (int row = 0; row < nr; ++row) {
- const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < ne01) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_q8_0_f32")]]
-kernel void kernel_mul_mv_q8_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-#define N_MV_T_T 4
-
-template<typename T0, typename T04, typename T1, typename T14>
-void kernel_mul_mv_impl(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg) {
- const int64_t r0 = tgpig.x;
- const int64_t rb = tgpig.y*N_MV_T_T;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
- device const T0 * x = (device const T0 *) (src0 + offset0);
-
- if (ne00 < 128) {
- for (int row = 0; row < N_MV_T_T; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const T1 * y = (device const T1 *) (src1 + offset1);
-
- float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (T0) x[i] * (T1) y[i];
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- } else {
- device const T04 * x4 = (device const T04 *) x;
- for (int row = 0; row < N_MV_T_T; ++row) {
- int r1 = rb + row;
- if (r1 >= ne11) {
- break;
- }
-
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const T1 * y = (device const T1 *) (src1 + offset1);
- device const T14 * y4 = (device const T14 *) y;
-
- float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
- }
-}
-
-template<typename T0, typename T04, typename T1, typename T14>
-kernel void kernel_mul_mv(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_impl<T0, T04, T1, T14>(
- src0,
- src1,
- dst,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- nb03,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- nb13,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
-
-typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
-
-template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
-template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
-template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
-#endif
-
-template<typename T, typename T4>
-kernel void kernel_mul_mv_1row(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const T * x = (device const T *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
-
- float sumf = 0;
- if (ne00 < 128) {
- for (int i = tiisg; i < ne00; i += 32) {
- sumf += (float) x[i] * (float) y[i];
- }
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- } else {
- device const T4 * x4 = (device const T4 *) x;
- device const float4 * y4 = (device const float4 *) y;
-
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
- }
-
- float all_sum = simd_sum(sumf);
-
- if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
-}
-
-typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
-#endif
-
-// Assumes row size (ne00) is a multiple of 4
-template<typename T, typename T4>
-kernel void kernel_mul_mv_l4(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
-
- const int nrows = ne11;
- const int64_t r0 = tgpig.x;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
- device const T4 * x4 = (device const T4 *) (src0 + offset0);
-
- for (int r1 = 0; r1 < nrows; ++r1) {
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const float4 * y4 = (device const float4 *) (src1 + offset1);
-
- float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
- }
-
- float all_sum = simd_sum(sumf);
- if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
- }
- }
-}
-
-typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
-#endif
-
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
- thread float * cos_theta, thread float * sin_theta) {
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = freq_scale * theta_extrap;
- float theta = theta_interp;
- if (ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
- }
- *cos_theta = cos(theta) * mscale;
- *sin_theta = sin(theta) * mscale;
-}
-
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
- return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
-}
-
-static void rope_yarn_corr_dims(
- int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
-) {
- // start and end correction dims
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
-}
-
-template<typename T>
-kernel void kernel_rope_norm(
- device const void * src0,
- device const int32_t * src1,
- device const float * src2,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int & n_past,
- constant int & n_dims,
- constant int & n_ctx_orig,
- constant float & freq_base,
- constant float & freq_scale,
- constant float & ext_factor,
- constant float & attn_factor,
- constant float & beta_fast,
- constant float & beta_slow,
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int64_t i3 = tgpig[2];
- const int64_t i2 = tgpig[1];
- const int64_t i1 = tgpig[0];
-
- float corr_dims[2];
- rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- device const int32_t * pos = src1;
-
- const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/n_dims;
-
- float cos_theta;
- float sin_theta;
-
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
- if (i0 < n_dims) {
- const int64_t ic = i0/2;
-
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
- const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
-
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = src[0];
- const float x1 = src[1];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
-}
-
-template<typename T>
-kernel void kernel_rope_neox(
- device const void * src0,
- device const int32_t * src1,
- device const float * src2,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int & n_past,
- constant int & n_dims,
- constant int & n_ctx_orig,
- constant float & freq_base,
- constant float & freq_scale,
- constant float & ext_factor,
- constant float & attn_factor,
- constant float & beta_fast,
- constant float & beta_slow,
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int64_t i3 = tgpig[2];
- const int64_t i2 = tgpig[1];
- const int64_t i1 = tgpig[0];
-
- float corr_dims[2];
- rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- device const int32_t * pos = src1;
-
- const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/n_dims;
-
- float cos_theta;
- float sin_theta;
-
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
- if (i0 < n_dims) {
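-            // NeoX-style RoPE rotates the pair (ic, ic + n_dims/2) rather than the adjacent pair used by kernel_rope_norm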
- const int64_t ic = i0/2;
-
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
- const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
-
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
-}
-
-typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
-typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
-
-template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
-template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
-
-template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
-template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
-
-typedef void (im2col_t)(
- device const float * x,
- device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
-
-template <typename T>
-kernel void kernel_im2col(
- device const float * x,
- device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
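-    // map the output element back to the input pixel: output position * stride + kernel tap * dilation - padding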
- const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
- const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
-
- const int32_t offset_dst =
- (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
- (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
-
- device T * pdst = (device T *) (dst);
-
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- pdst[offset_dst] = 0.0f;
- } else {
- const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
- pdst[offset_dst] = x[offset_src + iih * IW + iiw];
- }
-}
-
-template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
-template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
-
-typedef void (im2col_ext_t)(
- device const float * x,
- device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- constant int32_t & N,
- constant int32_t & KH,
- constant int32_t & KW,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
-
-template <typename T>
-kernel void kernel_im2col_ext(
- device const float * x,
- device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- constant int32_t & N,
- constant int32_t & KH,
- constant int32_t & KW,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
- const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
-
- const int32_t d = tgpig[0] / CHW;
- const int32_t chw = tgpig[0] % CHW;
- const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
- const int32_t HW = tgpig[0] % KHW;
-
- const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
- if (tpitg_0 >= N) {
- return;
- }
-
- const int32_t tpitg_1 = HW / KW;
- const int32_t tpitg_2 = HW % KW;
-
- const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
- const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
-
- const int32_t offset_dst =
- (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
- (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
-
- device T * pdst = (device T *) (dst);
-
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- pdst[offset_dst] = 0.0f;
- } else {
- const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
- pdst[offset_dst] = x[offset_src + iih * IW + iiw];
- }
-}
-
-template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
-template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
-
-kernel void kernel_upscale_f32(
- device const char * src0,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant float & sf0,
- constant float & sf1,
- constant float & sf2,
- constant float & sf3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
-
- const int64_t i03 = i3/sf3;
- const int64_t i02 = i2/sf2;
- const int64_t i01 = i1/sf1;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int64_t i00 = i0/sf0;
-
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_ptr[0] = src0_ptr[0];
- }
-}
-
-kernel void kernel_pad_f32(
- device const char * src0,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
-
- const int64_t i03 = i3;
- const int64_t i02 = i2;
- const int64_t i01 = i1;
-
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
-
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i0 < ne00) {
- dst_ptr[i0] = src0_ptr[i0];
- } else {
- dst_ptr[i0] = 0.0f;
- }
- }
-
- return;
- }
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- dst_ptr[i0] = 0.0f;
- }
-}
-
-kernel void kernel_arange_f32(
- device char * dst,
- constant int64_t & ne0,
- constant float & start,
- constant float & step,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- device float * dst_ptr = (device float *) dst;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- dst_ptr[i0] = start + step * i0;
- }
-}
-
-kernel void kernel_timestep_embedding_f32(
- device const char * src0,
- device char * dst,
- constant uint64_t & nb1,
- constant int & dim,
- constant int & max_period,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- int i = tgpig.x;
- device float * embed_data = (device float *)(dst + i*nb1);
-
- int half_ = dim / 2;
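-    // sinusoidal embedding: frequencies decay geometrically from 1 down to 1/max_period across the half dimension;
-    // cos components go into the first half of the output row, sin components into the second half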
- for (int j = tpitg.x; j < half_; j += ntg.x) {
- float timestep = ((device float *)src0)[i];
- float freq = (float)exp(-log((float)max_period) * j / half_);
- float arg = timestep * freq;
- embed_data[j ] = cos(arg);
- embed_data[j + half_] = sin(arg);
- }
-
- if (dim % 2 != 0 && tpitg.x == 0) {
- embed_data[dim] = 0.f;
- }
-}
-
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- constant int64_t & ncols_pad,
- threadgroup int32_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]);
-
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- constant int64_t & ncols_pad,
- threadgroup int32_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]) {
- // bitonic sort
- int col = tpitg[0];
- int row = tgpig[1];
-
- if (col >= ncols_pad) return;
-
- device const float * x_row = x + row * ncols;
- threadgroup int32_t * dst_row = shared_values;
-
- // initialize indices
- dst_row[col] = col;
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
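-    // standard bitonic sort network over the padded width (ncols_pad is expected to be a power of two);
-    // padding slots (index >= ncols) always sort towards the end of the row and are dropped in the final copy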
- for (int k = 2; k <= ncols_pad; k *= 2) {
- for (int j = k / 2; j > 0; j /= 2) {
- int ixj = col ^ j;
- if (ixj > col) {
- if ((col & k) == 0) {
- if (dst_row[col] >= ncols ||
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
- ) {
- SWAP(dst_row[col], dst_row[ixj]);
- }
- } else {
- if (dst_row[ixj] >= ncols ||
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
- ) {
- SWAP(dst_row[col], dst_row[ixj]);
- }
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- }
-
- // copy the result to dst without the padding
- if (col < ncols) {
- dst[row * ncols + col] = dst_row[col];
- }
-}
-
-template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
-
-kernel void kernel_leaky_relu_f32(
- device const float * src0,
- device float * dst,
- constant float & slope,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
-}
-
-// ref: https://arxiv.org/pdf/2307.08691.pdf
-template<
- typename q_t, // query types in shared memory
- typename q4_t,
- typename q8x8_t,
- typename k_t, // key types in shared memory
- typename k4x4_t,
- typename k8x8_t,
- typename v_t, // value types in shared memory
- typename v4x4_t,
- typename v8x8_t,
- typename qk_t, // Q*K types
- typename qk8x8_t,
- typename s_t, // soft-max types
- typename s8x8_t,
- typename o_t, // attention accumulation types
- typename o4_t,
- typename o8x8_t,
- typename kd4x4_t, // key type in device memory
- short nl_k,
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // value type in device memory
- short nl_v,
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short D, // head size
- short Q = 8, // queries per threadgroup
- short KV = 8, // key/value processed per each simdgroup
- short C = 32> // cache items per threadgroup
-kernel void kernel_flash_attn_ext(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int32_t & ne01,
- constant int32_t & ne02,
- constant int32_t & ne03,
- constant uint32_t & nb01,
- constant uint32_t & nb02,
- constant uint32_t & nb03,
- constant int32_t & ne11,
- constant int32_t & ne_12_2, // assume K and V are same shape
- constant int32_t & ne_12_3,
- constant uint32_t & nb_12_1,
- constant uint32_t & nb_12_2,
- constant uint32_t & nb_12_3,
- constant uint32_t & nb31,
- constant int32_t & ne1,
- constant int32_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint16_t & n_head_log2,
- constant float & logit_softcap,
- threadgroup half * shared [[threadgroup(0)]],
- ushort3 tgpig[[threadgroup_position_in_grid]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short nsg = ntg.y; // number of simdgroups
-
- const int iq3 = tgpig[2];
- const int iq2 = tgpig[1];
- const int iq1 = tgpig[0]*Q;
-
- const short D4 = D/4;
- const short D8 = D/8;
- const short D16 = D/16;
- const short NW = N_SIMDWIDTH;
- const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
-
-    const short TS = nsg*SH;   // shared memory size per query row, in s_t (float) elements
-    const short T  = D + 2*TS; // shared memory size per query row, in half elements
-
- threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
- threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
-
- threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
- threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
-
- threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
- threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
-
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- o8x8_t lo[D8];
-
- // load heads from Q to shared memory
- for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
-
- for (short i = tiisg; i < D4; i += NW) {
- if (iq1 + j < ne01) {
- sq4[j*D4 + i] = (q4_t) q4[i];
- } else {
- sq4[j*D4 + i] = (q4_t) 0.0f;
- }
- }
- }
-
- // zero out lo
- for (short i = 0; i < D8; ++i) {
- lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
- }
-
- // zero out shared memory SH
- for (short j = 0; j < Q; ++j) {
- for (short i = tiisg; i < SH; i += NW) {
- ss[j*TS + i] = 0.0f;
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- {
- half S[Q] = { [0 ... Q-1] = 0.0f };
- half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
-
- // thread indices inside the simdgroup
- // TODO: see if we can utilize quad-group functions for better performance
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
- const short tx = tiisg%4;
- const short ty = tiisg/4;
-
- // broadcast kv
- //const short rk2 = ne02/ne12;
- //const short rk3 = ne03/ne13;
-
- const short ikv2 = iq2/(ne02/ne_12_2);
- const short ikv3 = iq3/(ne03/ne_12_3);
-
- // load the queries from shared memory into local memory
- q8x8_t mq[D8];
-
- for (short i = 0; i < D8; ++i) {
- simdgroup_load(mq[i], sq + i*8, D);
- }
-
- const bool has_mask = mask != q;
-
- half slope = 1.0f;
-
- // ALiBi
- if (max_bias > 0.0f) {
- const short h = iq2;
-
- const half base = h < n_head_log2 ? m0 : m1;
- const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exph);
- }
-
- // loop over the KV cache
- // each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
- const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
- break;
- }
-
- if (has_mask) {
- // used to detect blocks full of -INF
- half smax = -INFINITY;
-
- // load the mask in shared memory
- #pragma unroll(Q)
- for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
-
- const half m = pm[ic + tiisg];
-
- ss[j*TS + C + tiisg] = m;
- smax = max(smax, m);
- }
-
- smax = simd_max(smax);
-
- if (smax == -INFINITY) {
- continue;
- }
- }
-
- // Q*K^T
- {
- for (short cc = 0; cc < C/8; ++cc) {
- qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
-
- // this is compile-time check, so it does not have runtime overhead
- if (is_same<kd4x4_t, k4x4_t>::value) {
- // we can read directly from global memory
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- #pragma unroll(D8)
- for (short i = 0; i < D8; ++i) {
- k8x8_t mk;
- simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
-
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
- }
- } else {
- for (short ii = 0; ii < D16; ii += 4) {
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- if (D16%4 == 0) {
- // the head is evenly divisible by 4*16 = 64, so no need for bound checks
- {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- k8x8_t mk;
-
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
-
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
- }
- } else {
- if (ii + tx < D16) {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
- k8x8_t mk;
-
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
-
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
- }
- }
- }
- }
-
- // cast qk_t -> s_t
- //s8x8_t mqks(1.0f);
- //simdgroup_multiply(mqks, mqk, mqks);
- //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-
- simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
- }
- }
-
- // online softmax
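-            // M[j] is the running row maximum and S[j] the running softmax denominator; previously accumulated
-            // output is rescaled by ms = exp(m - M[j]) via the diagonal matrix written at ss[j*TS + 2*C + j]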
- {
- for (ushort j = 0; j < Q; ++j) {
- const half m = M[j];
-
- // scale and apply the logitcap / mask
- half s = ss[j*TS + tiisg]*scale;
-
- if (logit_softcap != 0.0f) {
- s = logit_softcap*precise::tanh(s);
- }
-
- // mqk = mqk + mask*slope
- s += slope*ss[j*TS + C + tiisg];
-
- M[j] = simd_max(max(M[j], s));
-
- const half ms = exp(m - M[j]);
- const half vs = exp(s - M[j]);
-
- S[j] = S[j]*ms + simd_sum(vs);
-
- // the P matrix from the paper (Q rows, C columns)
- ss[j*TS + tiisg] = vs;
-
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg == j) {
- ss[j*TS + 2*C + j] = ms;
- }
- }
- }
-
- // O = diag(ms)*O
- {
- s8x8_t mm;
- simdgroup_load(mm, ss + 2*C, TS, 0, false);
-
- #pragma unroll(D8)
- for (short i = 0; i < D8; ++i) {
- simdgroup_multiply(lo[i], mm, lo[i]);
- }
- }
-
- // O = O + (Q*K^T)*V
- {
- for (short cc = 0; cc < C/8; ++cc) {
- s8x8_t ms;
- simdgroup_load(ms, ss + 8*cc, TS, 0, false);
-
- if (is_same<vd4x4_t, v4x4_t>::value) {
- // we can read directly from global memory
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- #pragma unroll(D8)
- for (short i = 0; i < D8; ++i) {
- v8x8_t mv;
- simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
-
- simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
- }
- } else {
- for (short ii = 0; ii < D16; ii += 4) {
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- if (D16%4 == 0) {
- // no need for bound checks
- {
- v4x4_t tmp;
- deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
- sv4x4[4*ty + tx] = tmp;
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- v8x8_t mv;
-
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
-
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
- }
- } else {
- if (ii + tx < D16) {
- v4x4_t tmp;
- deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
- sv4x4[4*ty + tx] = tmp;
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
- v8x8_t mv;
-
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
-
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
- }
- }
- }
- }
- }
- }
- }
-
- // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- for (short j = 0; j < Q; ++j) {
- if (tiisg == 0) {
- ss[j*TS + 0] = S[j];
- ss[j*TS + 1] = M[j];
- }
- }
- }
-
- // reduce the warps sequentially
- for (ushort sg = 1; sg < nsg; ++sg) {
- half S = { 0.0f };
- half M = { -__FLT16_MAX__/2 };
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // each simdgroup stores its output to shared memory, reusing sq
- if (sgitg == sg) {
- for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], so + i*8, D, 0, false);
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // the first simdgroup accumulates the results from the other simdgroups
- if (sgitg == 0) {
- for (short j = 0; j < Q; ++j) {
- const half S0 = ss[j*TS + 0];
- const half S1 = ss[j*TS + sg*SH + 0];
-
- const half M0 = ss[j*TS + 1];
- const half M1 = ss[j*TS + sg*SH + 1];
-
- M = max(M0, M1);
-
- const half ms0 = exp(M0 - M);
- const half ms1 = exp(M1 - M);
-
- S = S0*ms0 + S1*ms1;
-
- if (tiisg == 0) {
- ss[j*TS + 0] = S;
- ss[j*TS + 1] = M;
-
- ss[j*TS + 2*C + j ] = ms0;
- ss[j*TS + 2*C + j + sg*SH] = ms1;
- }
- }
-
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- {
- s8x8_t ms0;
- s8x8_t ms1;
-
- simdgroup_load(ms0, ss + 2*C, TS, 0, false);
- simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
-
- #pragma unroll(D8)
- for (short i = 0; i < D8; ++i) {
- o8x8_t t;
-
- simdgroup_load (t, so + i*8, D, 0, false);
- simdgroup_multiply(t, ms1, t);
-
- simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
- }
- }
- }
- }
-
- // store result to shared memory (reuse sq)
- if (sgitg == 0) {
- for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], so + i*8, D, 0, false);
- }
- }
-
- device float4 * dst4 = (device float4 *) dst;
-
- // final rescale with 1/S and store to global memory
- if (sgitg == 0) {
- for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
- const float S = ss[j*TS + 0];
-
- for (short i = tiisg; i < D4; i += NW) {
- dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
- }
- }
- }
-}
-
-// TODO: this is quite ugly. In the future these types will be hardcoded in the kernel, but for now keep them as
-// a template to be able to explore different combinations
-//
-#define FA_TYPES \
- half, half4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
- half, half4, simdgroup_half8x8
-
-typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
-
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
-
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
-#endif
-
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
-
-#undef FA_TYPES
-
-template<
- typename q4_t, // query types in shared memory
- typename q4x4_t,
- typename k4x4_t, // key types in shared memory
- typename v4x4_t, // value types in shared memory
- typename qk_t, // Q*K types
- typename s_t, // soft-max types
- typename s4_t,
- typename s4x4_t,
- typename o4x4_t, // attention accumulation types
- typename kd4x4_t, // key type in device memory
- short nl_k,
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
- typename vd4x4_t, // value type in device memory
- short nl_v,
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short D, // head size
- short Q = 1, // queries per threadgroup
- short C = 32> // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int32_t & ne01,
- constant int32_t & ne02,
- constant int32_t & ne03,
- constant uint32_t & nb01,
- constant uint32_t & nb02,
- constant uint32_t & nb03,
- constant int32_t & ne11,
- constant int32_t & ne_12_2, // assume K and V are same shape
- constant int32_t & ne_12_3,
- constant uint32_t & nb_12_1,
- constant uint32_t & nb_12_2,
- constant uint32_t & nb_12_3,
- constant uint32_t & nb31,
- constant int32_t & ne1,
- constant int32_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint16_t & n_head_log2,
- constant float & logit_softcap,
- threadgroup half * shared [[threadgroup(0)]],
- ushort3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short nsg = ntg.y; // number of simdgroups
-
- const int iq3 = tgpig[2];
- const int iq2 = tgpig[1];
- const int iq1 = tgpig[0];
-
- const short D4 = D/4;
- const short D16 = D/16;
- const short NW = N_SIMDWIDTH;
- const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
- const short SH = 2*C; // shared memory per simdgroup
-
- const short T = D + nsg*SH; // shared memory size per query in (half)
-
- //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
- threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
-
- // store the result for all queries in local memory (the O matrix from the paper)
- o4x4_t lo[D16/NL];
-
- // load heads from Q to shared memory
- device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
-
- for (short i = tiisg; i < D4; i += NW) {
- if (iq1 < ne01) {
- sq4[i] = (q4_t) q4[i];
- } else {
- sq4[i] = (q4_t) 0.0f;
- }
- }
-
- // zero out lo
- for (short i = 0; i < D16/NL; ++i) {
- lo[i] = (o4x4_t) 0.0f;
- }
-
- // zero out shared memory SH
- for (short i = tiisg; i < SH/4; i += NW) {
- ss4[i] = (s4_t) 0.0f;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- {
- half S = 0.0f;
- half M = -__FLT16_MAX__/2;
-
- // thread indices inside the simdgroup
- const short tx = tiisg%NL;
- const short ty = tiisg/NL;
-
- // broadcast kv
- //const short rk2 = ne02/ne12;
- //const short rk3 = ne03/ne13;
-
- const short ikv2 = iq2/(ne02/ne_12_2);
- const short ikv3 = iq3/(ne03/ne_12_3);
-
- // load the queries from shared memory into local memory
- q4x4_t mq[D16/NL];
-
- #pragma unroll(D16/NL)
- for (short ii = 0; ii < D16; ii += NL) {
- mq[ii/NL] = sq4x4[ii + tx];
- }
-
- const bool has_mask = mask != q;
-
- // pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*nb31);
-
- half slope = 1.0f;
-
- // ALiBi
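- // per-head slope: m0^(h + 1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) for the rest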
- if (max_bias > 0.0f) {
- const short h = iq2;
-
- const half base = h < n_head_log2 ? m0 : m1;
- const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exph);
- }
-
- // loop over the KV cache
- // each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
- const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
- break;
- }
-
- if (has_mask) {
- sm[tiisg] = pm[ic + tiisg];
- }
-
- // Q*K^T
- {
- // each simdgroup processes 1 query and 4 (NW/NL) keys
- for (short cc = 0; cc < C/4; ++cc) {
- qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
-
- device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- #pragma unroll(D16/NL)
- for (short ii = 0; ii < D16; ii += NL) {
- const short i = ii + tx;
-
- k4x4_t mk;
- deq_k(pk + i/nl_k, i%nl_k, mk);
-
- // note: this is less precise than the version below
- //mqka[0] += dot(mq[ii/NL][0], mk[0]);
- //mqka[1] += dot(mq[ii/NL][1], mk[1]);
- //mqka[2] += dot(mq[ii/NL][2], mk[2]);
- //mqka[3] += dot(mq[ii/NL][3], mk[3]);
-
- mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
- mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
- mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
- mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
- }
-
- qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
-
- // simdgroup reduce
- // [ 0 .. 7] -> [ 0]
- // [ 8 .. 15] -> [ 8]
- // [16 .. 23] -> [16]
- // [24 .. 31] -> [24]
- //mqk += simd_shuffle_down(mqk, 16);
- //mqk += simd_shuffle_down(mqk, 8);
- mqk += simd_shuffle_down(mqk, 4);
- mqk += simd_shuffle_down(mqk, 2);
- mqk += simd_shuffle_down(mqk, 1);
-
- // mqk = mqk*scale + mask*slope
- if (tx == 0) {
- mqk *= scale;
-
- if (logit_softcap != 0.0f) {
- mqk = logit_softcap*precise::tanh(mqk);
- }
-
- mqk += sm[4*cc + ty]*slope;
-
- ss[4*cc + ty] = mqk;
- }
- }
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- // online softmax
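- // keep a running row maximum M and denominator S: whenever M increases, the previous contributions
- // (both S and the O accumulator below) are rescaled by exp(m - M) to keep the softmax numerically stable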
- {
- const half m = M;
- const half s = ss[tiisg];
-
- M = simd_max(max(M, s));
-
- const half ms = exp(m - M);
- const half vs = exp(s - M);
-
- S = S*ms + simd_sum(vs);
-
- // the P matrix from the paper (Q rows, C columns)
- ss[tiisg] = vs;
-
- // O = diag(ms)*O
- #pragma unroll(D16/NL)
- for (short ii = 0; ii < D16; ii += NL) {
- lo[ii/NL] *= ms;
- }
- }
-
- simdgroup_barrier(mem_flags::mem_threadgroup);
-
- // O = O + (Q*K^T)*V
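- // each thread dequantizes its slice of one V row per iteration (row ic + 4*cc + ty) and accumulates it
- // into lo, scaled by that row's softmax weight ss[4*cc + ty]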
- {
- for (short cc = 0; cc < C/4; ++cc) {
- device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
- const s4x4_t ms(ss[4*cc + ty]);
-
- #pragma unroll(D16/NL)
- for (short ii = 0; ii < D16; ii += NL) {
- const short i = ii + tx;
-
- v4x4_t mv;
- deq_v(pv4 + i/nl_v, i%nl_v, mv);
-
- lo[ii/NL] += mv*ms;
- }
- }
- }
- }
-
- // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- if (tiisg == 0) {
- ss[0] = (s_t) S;
- ss[1] = (s_t) M;
- }
- }
-
- // simdgroup reduce
- // [ 0, 8, 16, 24] -> [ 0]
- // [ 1, 9, 17, 25] -> [ 1]
- // [ 2, 10, 18, 26] -> [ 2]
- // [ 3, 11, 19, 27] -> [ 3]
- // [ 4, 12, 20, 28] -> [ 4]
- // [ 5, 13, 21, 29] -> [ 5]
- // [ 6, 14, 22, 30] -> [ 6]
- // [ 7, 15, 23, 31] -> [ 7]
- for (short ii = 0; ii < D16; ii += NL) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
-
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
-
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
-
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // store results to shared memory
- for (short i = tiisg; i < D16; i += NL) {
- sr4x4[i] = lo[i/NL];
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // parallel reduce
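- // tree reduction over the simdgroups: at each step the lower half merges the (S, M) statistics and the
- // rescaled partial outputs of the upper half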
- for (short r = nsg/2; r > 0; r >>= 1) {
- if (sgitg < r) {
- const half S0 = ss[ 0];
- const half S1 = ss[r*SH + 0];
-
- const half M0 = ss[ 1];
- const half M1 = ss[r*SH + 1];
-
- const half M = max(M0, M1);
-
- const half ms0 = exp(M0 - M);
- const half ms1 = exp(M1 - M);
-
- const half S = S0*ms0 + S1*ms1;
-
- if (tiisg == 0) {
- ss[0] = S;
- ss[1] = M;
- }
-
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- for (short i = tiisg; i < D16; i += NW) {
- sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- device float4x4 * dst44 = (device float4x4 *) dst;
-
- // final rescale with 1/S and store to global memory
- if (sgitg == 0) {
- const float S = ss[0];
-
- for (short i = tiisg; i < D16; i += NW) {
- dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
- }
- }
-}
-
-// note: I think s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
-// in the other (non-vec) kernel we need s_t to also be float, because we scale during the soft_max
-//
-#define FA_TYPES \
- half4, half4x4, \
- half4x4, \
- half4x4, \
- float, \
- half, half4, half4x4, \
- half4x4
-
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
-
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
-
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
-
-#undef FA_TYPES
-
-template<typename T0, typename T1>
-kernel void kernel_cpy(
- device const void * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
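- // convert the flat element index of this source row into destination coordinates (i3, i2, i1, i0)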
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
- device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- dst_data[i00] = (T1) src[0];
- }
-}
-
-typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
-
-template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
-template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
-#endif
-template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
-template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
-template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
-#endif
-
-kernel void kernel_cpy_f32_q8_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
-
- device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
- }
-
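- // q8_0: per-block scale d = amax/127; each value is stored as round(v/d) in an int8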
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK8_0].d = d;
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
-
- dst_data[i00/QK8_0].qs[j] = round(x0);
- }
- }
-}
-
-kernel void kernel_cpy_f32_q4_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
-
- device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK4_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
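- // q4_0: the scale maps the signed max to -8; quants are offset by 8, clamped to at most 15 and packed two per byte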
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK4_0].d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_0/2 + j]*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
- dst_data[i00/QK4_0].qs[j] = xi0;
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
- }
- }
-}
-
-kernel void kernel_cpy_f32_q4_1(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
-
- device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < QK4_1; j++) {
- const float v = src[j];
- if (min > v) min = v;
- if (max < v) max = v;
- }
-
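- // q4_1: affine quantization with per-block scale d = (max - min)/15 and offset m = min; 4-bit quants, two per byte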
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK4_1].d = d;
- dst_data[i00/QK4_1].m = min;
-
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK4_1/2 + j] - min)*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
- dst_data[i00/QK4_1].qs[j] = xi0;
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
- }
- }
-}
-
-kernel void kernel_cpy_f32_q5_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
-
- device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK5_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / -16;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK5_0].d = d;
-
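- // the low 4 bits of each 5-bit quant are packed two per byte into qs; the 5th bits are collected in the 32-bit qh mask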
- uint32_t qh = 0;
- for (int j = 0; j < QK5_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK5_0/2 + j]*id;
-
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_0].qh[j] = qh8[j];
- }
- }
-}
-
-kernel void kernel_cpy_f32_q5_1(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
-
- device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float max = src[0];
- float min = src[0];
-
- for (int j = 1; j < QK5_1; j++) {
- const float v = src[j];
- min = v < min ? v : min;
- max = v > max ? v : max;
- }
-
- const float d = (max - min) / 31;
- const float id = d ? 1.0f/d : 0.0f;
-
- dst_data[i00/QK5_1].d = d;
- dst_data[i00/QK5_1].m = min;
-
- uint32_t qh = 0;
- for (int j = 0; j < QK5_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK5_1/2 + j] - min)*id;
-
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst_data[i00/QK5_1].qh[j] = qh8[j];
- }
- }
-}
-
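-// binary search for the index of the value in a sorted table that is closest to x (used for the iq4_nl grid below)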
-static inline int best_index_int8(int n, constant float * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
- }
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
-kernel void kernel_cpy_f32_iq4_nl(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
-
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
-
- device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < QK4_NL; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
-
- const float d = max / kvalues_iq4nl_f[0];
- const float id = d ? 1.0f/d : 0.0f;
-
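- // refine the scale with a weighted least-squares fit over the selected grid values: d = sum(w*x*q)/sum(w*q*q), w = x^2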
- float sumqx = 0, sumq2 = 0;
- for (int j = 0; j < QK4_NL/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_NL/2 + j]*id;
-
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
- dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
- const float v0 = kvalues_iq4nl_f[xi0];
- const float v1 = kvalues_iq4nl_f[xi1];
- const float w0 = src[0 + j]*src[0 + j];
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
- sumq2 += w0*v0*v0 + w1*v1*v1;
- }
-
- dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
- }
-}
-
-kernel void kernel_concat(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int32_t & dim,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
-
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
-
- int64_t o[4] = {0, 0, 0, 0};
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
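- // elements inside the src0 box read from src0; the rest read from src1 with the concat dimension shifted back by o[dim]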
-
- device const float * x;
-
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
- x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
- } else {
- x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
- }
-
- device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- *y = *x;
- }
-}
-
-void kernel_mul_mv_q2_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int ix = tiisg/8; // 0...3
- const int it = tiisg%8; // 0...7
- const int iq = it/4; // 0 or 1
- const int ir = it%4; // 0...3
- const int is = (8*ir)/16; // 0 or 1
-
- device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
-
- for (int ib = ix; ib < nb; ib += 4) {
-
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
- yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
- yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
- }
-
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
- device const half * dh = &x[ib].d;
-
- for (int row = 0; row < N_DST; row++) {
-
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
- }
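- // the 2-bit quants are used unshifted, so the 1/4, 1/16, 1/64 and 1/256 factors compensate for their bit positions;
- // the low nibble of each scale is the multiplier, the high nibble (via dmin = d_min/16) subtracts the sub-block minimum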
- float dall = dh[0];
- float dmin = dh[1] * 1.f/16.f;
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
- dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
-
- qs += nb01/2;
- sc += nb01;
- dh += nb01/2;
- }
-
- y4 += 4 * QK_K;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_q2_K_f32")]]
-kernel void kernel_mul_mv_q2_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q3_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int64_t im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
-
- //const uint16_t kmask1 = 0x3030;
- //const uint16_t kmask2 = 0x0f0f;
-
- const int tid = tiisg/4;
- const int ix = tiisg%4;
- const int ip = tid/4; // 0 or 1
- const int il = 2*((tid%4)/2); // 0 or 2
- const int ir = tid%2;
- const int n = 8;
- const int l0 = n*ir;
-
- // One would think that the Metal compiler would figure out that ip and il can only have
- // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
- // with these two tables.
- //
- // Possible masks for the high bit
- const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
- {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
- {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
- {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
-
- // Possible masks for the low 2 bits
- const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
-
- const ushort4 hm = mm[2*ip + il/2];
-
- const int shift = 2*il;
- const float v1 = il == 0 ? 4.f : 64.f;
- const float v2 = 4.f * v1;
-
- const uint16_t s_shift1 = 4*ip;
- const uint16_t s_shift2 = s_shift1 + il;
-
- const int q_offset = 32*ip + l0;
- const int y_offset = 128*ip + 32*il + l0;
-
- device const float * y1 = yy + ix*QK_K + y_offset;
-
- uint32_t scales32, aux32;
- thread uint16_t * scales16 = (thread uint16_t *)&scales32;
- thread const int8_t * scales = (thread const int8_t *)&scales32;
-
- float sumf1[2] = {0.f};
- float sumf2[2] = {0.f};
- for (int i = ix; i < nb; i += 4) {
- for (int l = 0; l < 8; ++l) {
- yl[l+ 0] = y1[l+ 0];
- yl[l+ 8] = y1[l+16];
- yl[l+16] = y1[l+32];
- yl[l+24] = y1[l+48];
- }
-
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
- device const uint16_t * a = (device const uint16_t *)(x[i].scales);
- device const half * dh = &x[i].d;
-
- for (int row = 0; row < 2; ++row) {
- const float d_all = (float)dh[0];
-
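- // assemble the 6-bit q3_K scales: low 4 bits come from a[il..il+1], high 2 bits from a[4..5];
- // they are centered by subtracting 32 when applied below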
- scales16[0] = a[4];
- scales16[1] = a[5];
- aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
- scales16[0] = a[il+0];
- scales16[1] = a[il+1];
- scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
-
- float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
- for (int l = 0; l < n; l += 2) {
- const int32_t qs = q[l/2];
- s1 += yl[l+0] * (qs & qm[il/2][0]);
- s2 += yl[l+1] * (qs & qm[il/2][1]);
- s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
- s4 += yl[l+16] * (qs & qm[il/2][2]);
- s5 += yl[l+17] * (qs & qm[il/2][3]);
- s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
- }
- float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
- float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
- sumf1[row] += d1 * (scales[0] - 32);
- sumf2[row] += d2 * (scales[2] - 32);
-
- s1 = s2 = s3 = s4 = s5 = s6 = 0;
- for (int l = 0; l < n; l += 2) {
- const int32_t qs = q[l/2+8];
- s1 += yl[l+8] * (qs & qm[il/2][0]);
- s2 += yl[l+9] * (qs & qm[il/2][1]);
- s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
- s4 += yl[l+24] * (qs & qm[il/2][2]);
- s5 += yl[l+25] * (qs & qm[il/2][3]);
- s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
- }
- d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
- d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
- sumf1[row] += d1 * (scales[1] - 32);
- sumf2[row] += d2 * (scales[3] - 32);
-
- q += nb01/2;
- h += nb01/2;
- a += nb01/2;
- dh += nb01/2;
- }
-
- y1 += 4 * QK_K;
- }
-
- for (int row = 0; row < 2; ++row) {
- const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
- sumf1[row] = simd_sum(sumf);
- }
- if (tiisg == 0) {
- for (int row = 0; row < 2; ++row) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
- }
- }
-}
-
-[[host_name("kernel_mul_mv_q3_K_f32")]]
-kernel void kernel_mul_mv_q3_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q4_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
- const int ix = tiisg/8; // 0...3
- const int it = tiisg%8; // 0...7
- const int iq = it/4; // 0 or 1
- const int ir = it%4; // 0...3
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int first_row = r0 * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[16];
- float yh[16];
- float sumf[N_DST]={0.f}, all_sum;
-
- device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
-
- uint16_t sc16[4];
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
-
- for (int ib = ix; ib < nb; ib += 4) {
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
- yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
- yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
- yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
- yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
- }
-
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
- device const half * dh = &x[ib].d;
-
- for (int row = 0; row < N_DST; row++) {
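- // unpack the 6-bit scales and mins of the super-block: sc8[0,1,4,5] are scales, sc8[2,3,6,7] are mins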
- sc16[0] = sc[0] & kmask1;
- sc16[1] = sc[2] & kmask1;
- sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
- sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
-
- device const uint16_t * q2 = q1 + 32;
-
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
- acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
- acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
- acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
- acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
- acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
- acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
- acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
- }
-
- float dall = dh[0];
- float dmin = dh[1];
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
- (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
- (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
- (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
-
- q1 += nb01/2;
- sc += nb01/2;
- dh += nb01/2;
- }
-
- y4 += 4 * QK_K;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_q4_K_f32")]]
-kernel void kernel_mul_mv_q4_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q5_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
-
- float sumf[2]={0.f};
-
- float yl[16], yh[16];
-
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
- const int tid = tiisg/4;
- const int ix = tiisg%4;
- const int iq = tid/4;
- const int ir = tid%4;
- const int n = 8;
-
- const int l0 = n*ir;
- const int q_offset = 32*iq + l0;
- const int y_offset = 64*iq + l0;
-
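- // bit masks that extract the 5th quant bit from qh for the four value groups accumulated below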
- const uint8_t hm1 = 1u << (2*iq);
- const uint8_t hm2 = hm1 << 1;
- const uint8_t hm3 = hm1 << 4;
- const uint8_t hm4 = hm2 << 4;
-
- uint16_t sc16[4];
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
-
- device const float * y1 = yy + ix*QK_K + y_offset;
-
- for (int i = ix; i < nb; i += 4) {
- device const uint8_t * q1 = x[i].qs + q_offset;
- device const uint8_t * qh = x[i].qh + l0;
- device const half * dh = &x[i].d;
- device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
-
- device const float * y2 = y1 + 128;
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int l = 0; l < 8; ++l) {
- yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
- yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
- yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
- yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
- }
-
- for (int row = 0; row < 2; ++row) {
- device const uint8_t * q2 = q1 + 64;
-
- sc16[0] = a[0] & kmask1;
- sc16[1] = a[2] & kmask1;
- sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
- sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
-
- float4 acc1 = {0.f};
- float4 acc2 = {0.f};
- for (int l = 0; l < n; ++l) {
- uint8_t h = qh[l];
- acc1[0] += yl[l+0] * (q1[l] & 0x0F);
- acc1[1] += yl[l+8] * (q1[l] & 0xF0);
- acc1[2] += yh[l+0] * (q2[l] & 0x0F);
- acc1[3] += yh[l+8] * (q2[l] & 0xF0);
- acc2[0] += h & hm1 ? yl[l+0] : 0.f;
- acc2[1] += h & hm2 ? yl[l+8] : 0.f;
- acc2[2] += h & hm3 ? yh[l+0] : 0.f;
- acc2[3] += h & hm4 ? yh[l+8] : 0.f;
- }
- const float dall = dh[0];
- const float dmin = dh[1];
- sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
- sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
- sc8[4] * (acc1[2] + 16.f*acc2[2]) +
- sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
-
- q1 += nb01;
- qh += nb01;
- dh += nb01/2;
- a += nb01/2;
- }
-
- y1 += 4 * QK_K;
- }
-
- for (int row = 0; row < 2; ++row) {
- const float tot = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_q5_K_f32")]]
-kernel void kernel_mul_mv_q5_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q6_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const uint8_t kmask1 = 0x03;
- const uint8_t kmask2 = 0x0C;
- const uint8_t kmask3 = 0x30;
- const uint8_t kmask4 = 0xC0;
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int row = 2 * r0 + sgitg;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
-
- float sumf = 0;
-
- const int tid = tiisg/2;
- const int ix = tiisg%2;
- const int ip = tid/8; // 0 or 1
- const int il = tid%8;
- const int n = 4;
- const int l0 = n*il;
- const int is = 8*ip + l0/16;
-
- const int y_offset = 128*ip + l0;
- const int q_offset_l = 64*ip + l0;
- const int q_offset_h = 32*ip + l0;
-
- for (int i = ix; i < nb; i += 2) {
-
- device const uint8_t * q1 = x[i].ql + q_offset_l;
- device const uint8_t * q2 = q1 + 32;
- device const uint8_t * qh = x[i].qh + q_offset_h;
- device const int8_t * sc = x[i].scales + is;
-
- device const float * y = yy + i * QK_K + y_offset;
-
- const float dall = x[i].d;
-
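- // reconstruct the 6-bit quants: low 4 bits from ql, high 2 bits from qh, then subtract the bias of 32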
- float4 sums = {0.f, 0.f, 0.f, 0.f};
- for (int l = 0; l < n; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
- }
-
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
-
- }
-
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
- }
-}
-
-[[host_name("kernel_mul_mv_q6_K_f32")]]
-kernel void kernel_mul_mv_q6_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-// ======================= "True" 2-bit
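-//
-// the non-linear iq2/iq3 kernels below dequantize on the fly from small codebooks;
-// most of them first stage the codebook grid (and, where used, the shared sign
-// table ksigns_iq2xs) cooperatively into threadgroup memory, and each simdgroup
-// then accumulates N_DST rows of the result before the final simd_sum reduction
-//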
-
-void kernel_mul_mv_iq2_xxs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
- {
- int nval = 4;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq2_xxs * xr = x + ibl;
- device const uint16_t * q2 = xr->qs + 4 * ib;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- const float db = dh[0];
- device const uint8_t * aux8 = (device const uint8_t *)q2;
- const uint32_t aux32 = q2[2] | (q2[3] << 16);
- const float d = db * (0.5f + (aux32 >> 28));
-
- float sum = 0;
- for (int l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
- const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
- for (int j = 0; j < 8; ++j) {
- sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- sumf[row] += d * sum;
-
- dh += nb01/2;
- q2 += nb01/2;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_iq2_xxs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq2_xs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
- {
- int nval = 8;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq2_xs * xr = x + ibl;
- device const uint16_t * q2 = xr->qs + 4 * ib;
- device const uint8_t * sc = xr->scales + ib;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- const float db = dh[0];
- const uint8_t ls1 = sc[0] & 0xf;
- const uint8_t ls2 = sc[0] >> 4;
- const float d1 = db * (0.5f + ls1);
- const float d2 = db * (0.5f + ls2);
-
- float sum1 = 0, sum2 = 0;
- for (int l = 0; l < 2; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
- const uint8_t signs = shared_signs[(q2[l] >> 9)];
- for (int j = 0; j < 8; ++j) {
- sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- for (int l = 2; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
- const uint8_t signs = shared_signs[(q2[l] >> 9)];
- for (int j = 0; j < 8; ++j) {
- sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- sumf[row] += d1 * sum1 + d2 * sum2;
-
- dh += nb01/2;
- q2 += nb01/2;
- sc += nb01;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq2_xs_f32")]]
-kernel void kernel_mul_mv_iq2_xs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq3_xxs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
- {
- int nval = 4;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq3_xxs * xr = x + ibl;
- device const uint8_t * q3 = xr->qs + 8 * ib;
- device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- const float db = dh[0];
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
- const float d = db * (0.5f + (aux32 >> 28));
-
- float2 sum = {0};
- for (int l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
- const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
- for (int j = 0; j < 4; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
- sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
- }
- }
- sumf[row] += d * (sum[0] + sum[1]);
-
- dh += nb01/2;
- q3 += nb01;
- gas += nb01/2;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_iq3_xxs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq3_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
- {
- int nval = 8;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq3_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 8 * ib;
- device const uint8_t * qh = xr->qh + ib;
- device const uint8_t * sc = xr->scales + (ib/2);
- device const uint8_t * signs = xr->signs + 4 * ib;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- const float db = dh[0];
- const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
-
- float2 sum = {0};
- for (int l = 0; l < 4; ++l) {
- const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
- const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
- for (int j = 0; j < 4; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
- sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
- }
- }
- sumf[row] += d * (sum[0] + sum[1]);
-
- dh += nb01/2;
- qs += nb01;
- qh += nb01;
- sc += nb01;
- signs += nb01;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq3_s_f32")]]
-kernel void kernel_mul_mv_iq3_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq2_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
- //{
- // int nval = 32;
- // int pos = (32*sgitg + tiisg)*nval;
- // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq2_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint8_t * qh = xr->qh + ib;
- device const uint8_t * sc = xr->scales + ib;
- device const uint8_t * signs = qs + QK_K/8;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- const float db = dh[0];
- const float d1 = db * (0.5f + (sc[0] & 0xf));
- const float d2 = db * (0.5f + (sc[0] >> 4));
-
- float2 sum = {0};
- for (int l = 0; l < 2; ++l) {
- //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
- //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
- constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
- for (int j = 0; j < 8; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
- sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
- }
- }
- sumf[row] += d1 * sum[0] + d2 * sum[1];
-
- dh += nb01/2;
- qs += nb01;
- qh += nb01;
- sc += nb01;
- signs += nb01;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq2_s_f32")]]
-kernel void kernel_mul_mv_iq2_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
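-// iq1_s / iq1_m look their codebook up directly from constant memory
-// (iq1s_grid_gpu) instead of staging it in threadgroup memory; the
-// -1 +/- IQ1S_DELTA / IQ1M_DELTA offset is folded in through the running
-// sum of the activations (sumy)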
-void kernel_mul_mv_iq1_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_value,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- float sumy = 0;
- for (int i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- sumy += yl[i];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq1_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint16_t * qh = xr->qh + ib;
- device const half * dh = &xr->d;
-
- for (int row = 0; row < N_DST; row++) {
-
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
- constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
- constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
-
- float sum = 0;
- for (int j = 0; j < 4; ++j) {
- sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
- + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
- + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
- + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
- }
- sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
-
- dh += nb01/2;
- qs += nb01;
- qh += nb01/2;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-void kernel_mul_mv_iq1_m_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_value,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
-
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int nb32 = nb * (QK_K / 32);
-
- const int ix = tiisg;
-
- device const float * y4 = y + 32 * ix;
-
- iq1m_scale_t scale;
-
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
- float4 sumy = {0.f};
- for (int i = 0; i < 8; ++i) {
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
- yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
- yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
- }
-
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
-
- device const block_iq1_m * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint8_t * qh = xr->qh + 2 * ib;
- device const uint16_t * sc = (device const uint16_t *)xr->scales;
-
- for (int row = 0; row < N_DST; row++) {
- scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
- constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
- constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
-
- float2 sum = {0.f};
- for (int j = 0; j < 4; ++j) {
- sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
- + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
- sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
- + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
- }
- const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-
- sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
- (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
-
- sc += nb01/2;
- qs += nb01;
- qh += nb01;
- }
-
- y4 += 32 * 32;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
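-// iq4_nl / iq4_xs stage the 16-entry non-linear codebook kvalues_iq4nl_f into
-// threadgroup memory, then unpack two nibbles per byte via the 0x0f0f0f0f masks
-// on aux32 and look the weights up from that shared table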
-void kernel_mul_mv_iq4_nl_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values_i8,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
- const int nb = ne00/QK4_NL;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * 2 + sgitg) * 2;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- const int ix = tiisg/2; // 0...15
- const int it = tiisg%2; // 0 or 1
-
- shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- float4 yl[4];
- float sumf[2]={0.f}, all_sum;
-
- device const float * yb = y + ix * QK4_NL + it * 8;
-
- uint32_t aux32[2];
- thread const uint8_t * q8 = (thread const uint8_t *)aux32;
-
- float4 qf1, qf2;
-
- for (int ib = ix; ib < nb; ib += 16) {
-
- device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
- for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
-
- device const block_iq4_nl & xb = x[row*nb + ib];
- device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
-
- float4 acc1 = {0.f}, acc2 = {0.f};
-
- aux32[0] = q4[0] | (q4[1] << 16);
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
- aux32[0] &= 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
- acc1 += yl[0] * qf1;
- acc2 += yl[1] * qf2;
-
- aux32[0] = q4[2] | (q4[3] << 16);
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
- aux32[0] &= 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
- acc1 += yl[2] * qf1;
- acc2 += yl[3] * qf2;
-
- acc1 += acc2;
-
- sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
- }
-
- yb += 16 * QK4_NL;
- }
-
- for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-void kernel_mul_mv_iq4_xs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values_i8,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * 2 + sgitg) * 2;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
-
- device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
-
- const int ix = tiisg/16; // 0 or 1
- const int it = tiisg%16; // 0...15
- const int ib = it/2;
- const int il = it%2;
-
- shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- float4 yl[4];
- float sumf[2]={0.f}, all_sum;
-
- device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
-
- uint32_t aux32[2];
- thread const uint8_t * q8 = (thread const uint8_t *)aux32;
-
- float4 qf1, qf2;
-
- for (int ibl = ix; ibl < nb; ibl += 2) {
-
- device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
- for (int row = 0; row < 2; ++row) {
-
- device const block_iq4_xs & xb = x[row*nb + ibl];
- device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
-
- float4 acc1 = {0.f}, acc2 = {0.f};
-
- aux32[0] = q4[0] & 0x0f0f0f0f;
- aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
- acc1 += yl[0] * qf1;
- acc2 += yl[1] * qf2;
-
- aux32[0] = q4[1] & 0x0f0f0f0f;
- aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
- acc1 += yl[2] * qf1;
- acc2 += yl[3] * qf2;
-
- acc1 += acc2;
-
- const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
- sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
- }
-
- yb += 2 * QK_K;
- }
-
- for (int row = 0; row < 2; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
-}
-
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_xs_f32")]]
-kernel void kernel_mul_mv_iq4_xs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows_q(
- device const void * src0,
- device const void * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int64_t i10 = tgpig.x;
- const int64_t i11 = tgpig.y;
-
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
- const int64_t i02 = i11;
-
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
- float4x4 temp;
- dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
- }
-}
-
-template<typename T>
-kernel void kernel_get_rows_f(
- device const void * src0,
- device const void * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int64_t i10 = tgpig.x;
- const int64_t i11 = tgpig.y;
-
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
- const int64_t i02 = i11;
-
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
- }
-}
-
-kernel void kernel_get_rows_i32(
- device const void * src0,
- device const void * src1,
- device int32_t * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int64_t i10 = tgpig.x;
- const int64_t i11 = tgpig.y;
-
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
- const int64_t i02 = i11;
-
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
- }
-}
-
-
-#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
-#define BLOCK_SIZE_K 32
-#define THREAD_MAT_M 4 // each simdgroup takes 4 simdgroup matrices from matrix A
-#define THREAD_MAT_N 2 // each simdgroup takes 2 simdgroup matrices from matrix B
-#define THREAD_PER_BLOCK 128
-#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
-#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
-#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
-#define SG_MAT_ROW 8
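-
-// each threadgroup of THREAD_PER_BLOCK = 128 threads (4 simdgroups) computes a
-// BLOCK_SIZE_M x BLOCK_SIZE_N = 64x32 tile of dst; each simdgroup accumulates a
-// 32x16 sub-tile as 8 simdgroup matrices of shape 8x8 (4 along M, 2 along N),
-// stepping the K dimension by BLOCK_SIZE_K = 32 through the threadgroup
-// buffers sa (src0) and sb (src1)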
-
-// each block_q contains 16*nl weights
-template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup T * sa = (threadgroup T *)(shared_memory);
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
- const uint r0 = tgpig.y;
- const uint r1 = tgpig.x;
- const uint im = tgpig.z;
-
- // if this block is of 64x32 shape or smaller
- short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
-
- // a thread shouldn't load data outside of the matrix
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
- simdgroup_T8x8 ma[4];
- simdgroup_float8x8 mb[2];
- simdgroup_float8x8 mc[8];
-
- for (short i = 0; i < 8; i++){
- mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
- }
-
- short il = (tiitg % THREAD_PER_ROW);
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
- ushort offset1 = il/nl;
-
- device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
- device const float * y = (device const float *)(src1
- + nb13 * i13
- + nb12 * i12
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
-
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
- // load data and store to threadgroup memory
- T4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- #pragma unroll(16)
- for (short i = 0; i < 16; i++) {
- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
- + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
- + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
- }
-
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
-
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2+nl-1)/nl : x;
- y += BLOCK_SIZE_K;
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // load matrices from threadgroup memory and conduct outer products
- threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
- threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
-
- #pragma unroll(4)
- for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
- #pragma unroll(4)
- for (short i = 0; i < 4; i++) {
- simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
- }
- simdgroup_barrier(mem_flags::mem_none);
- #pragma unroll(2)
- for (short i = 0; i < 2; i++) {
- simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
- }
-
- lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
- lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
-
- #pragma unroll(8)
- for (short i = 0; i < 8; i++){
- simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
- }
- }
- }
-
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
- }
- } else {
-        // the block is smaller than 64x32, so avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
- + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (sgitg == 0) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
- device float4 * D4 = (device float4 *) D;
-
- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
- threadgroup float4 * C4 = (threadgroup float4 *) C;
-
- int i = 0;
- for (; i < n_rows/4; i++) {
- *(D4 + i) = *(C4 + i);
- }
-
- i *= 4;
- for (; i < n_rows; i++) {
- *(D + i) = *(C + i);
- }
- }
- }
- }
-}
-
-// same as kernel_mul_mm above, but src1 and dst are accessed via indices stored in rowids
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_id_impl(
- device const uchar * src0,
- device const uchar * src1,
- threadgroup ushort2 * rowids,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- int64_t ne1,
- int64_t ne0ne1,
- threadgroup uchar * shared_memory,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup half * sa = (threadgroup half *)(shared_memory);
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
- const uint r0 = tgpig.y;
- const uint r1 = tgpig.x;
-
- if (r1 * BLOCK_SIZE_N >= ne1) return;
-
- // if this block is of 64x32 shape or smaller
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
-
- // a thread shouldn't load data outside of the matrix
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
- simdgroup_half8x8 ma[4];
- simdgroup_float8x8 mb[2];
- simdgroup_float8x8 c_res[8];
- for (int i = 0; i < 8; i++){
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
- }
- short il = (tiitg % THREAD_PER_ROW);
-
- ushort offset1 = il/nl;
-
- threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
-
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
- device const float * y = (device const float *)(src1
- + nb12 * id[1]
- + nb11 * (id[0] % ne11)
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
-
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
- // load data and store to threadgroup memory
- half4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- for (int i = 0; i < 16; i++) {
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
- }
-
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
-
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2+nl-1)/nl : x;
- y += BLOCK_SIZE_K;
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // load matrices from threadgroup memory and conduct outer products
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
-
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
- for (int i = 0; i < 4; i++) {
- simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
- }
- simdgroup_barrier(mem_flags::mem_none);
- for (int i = 0; i < 2; i++) {
- simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
- }
-
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
-
- for (int i = 0; i < 8; i++){
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
- }
- }
- }
-
- {
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
- for (int i = 0; i < 8; i++) {
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- device float * C = dst + (BLOCK_SIZE_M * r0);
- if (sgitg == 0) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
- int joff = jid[0] * ne0 + jid[1] * ne0ne1;
- for (int i = 0; i < n_rows; i++) {
- *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
- }
- }
- }
- }
-}
-
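-// kernel_mul_mm_id: for the expert selected by tgpig.z, scan the ids matrix and
-// collect every (ii0, ii1) position whose id matches this expert into rowids in
-// threadgroup memory, then run kernel_mul_mm_id_impl over just the rows of
-// src1/dst referenced by those pairs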
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm_id(
- device const uchar * src0s,
- device const uchar * src1,
- device float * dst,
- device const uchar * ids,
- constant int64_t & nei0,
- constant int64_t & nei1,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- const int32_t i02 = tgpig.z;
- tgpig.z = 0;
-
- device const uchar * src0 = src0s + i02*nb02;
-
- // row indices
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
-
- // TODO: parallelize this loop
- int64_t _ne1 = 0;
- for (ushort ii1 = 0; ii1 < nei1; ii1++) {
- for (ushort ii0 = 0; ii0 < nei0; ii0++) {
- int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
- if (id == i02) {
- //if (tiitg == 0) {
- rowids[_ne1] = ushort2(ii0, ii1);
- //}
- _ne1++;
- }
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
- src0,
- src1,
- rowids,
- dst,
- ne00,
- ne02,
- nb01,
- nb02,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- _ne1,
- ne0*ne1,
- shared_memory,
- tgpig,
- tiitg,
- sgitg);
-}
-
-#define QK_NL 16
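-// QK_NL is the nl template argument used for the K-quant and IQ block types:
-// with 16*nl weights per block_q (see the note above kernel_mul_mm) this gives
-// the 256-weight super-blocks, while the 32-weight legacy quants use nl = 2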
-
-//
-// get rows
-//
-
-typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
-
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
-#endif
-
-typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
-
-template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-
-//
-// matrix-matrix multiplication
-//
-
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
-#endif
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-
-//
-// indirect matrix-matrix multiplication
-//
-
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
-
-template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
-#endif
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-
-//
-// matrix-vector multiplication
-//
-
-typedef void (kernel_mul_mv_impl_t)(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg);
-
-typedef void (kernel_mul_mv2_impl_t)(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg);
-
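-// wrappers adapting the two matrix-vector implementation signatures above to a
-// single argument list, so that kernel_mul_mv_id can be instantiated over either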
-template<kernel_mul_mv_impl_t impl_fn>
-void mmv_fn(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- int64_t ne13,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint64_t nb1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiitg,
- uint tiisg,
- uint sgitg) {
- impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
-}
-
-template<kernel_mul_mv2_impl_t impl_fn>
-void mmv_fn(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- int64_t ne13,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint64_t nb1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiitg,
- uint tiisg,
- uint sgitg) {
- impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
-}
-
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
-
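-// indirect matrix-vector multiplication: the ids tensor selects which matrix
-// from src0s is multiplied with each row of src1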
-template<mul_mv_impl_fn_t impl_fn>
-kernel void kernel_mul_mv_id(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant int64_t & nei0,
- constant int64_t & nei1,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int iid1 = tgpig.z/nei0;
- const int idx = tgpig.z%nei0;
-
- tgpig.z = 0;
-
- const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
-
- const int64_t i11 = idx % ne11;
- const int64_t i12 = iid1;
-
- const int64_t i1 = idx;
- const int64_t i2 = i12;
-
- device const char * src0_cur = src0s + i02*nb02;
- device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
- device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
-
- impl_fn(
- /* src0 */ src0_cur,
- /* src1 */ src1_cur,
- /* dst */ dst_cur,
- /* ne00 */ ne00,
- /* ne01 */ ne01,
- /* ne02 */ 1, // ne02,
- /* nb00 */ nb00,
- /* nb01 */ nb01,
- /* nb02 */ nb02,
- /* nb03 */ nb02, // ne02 == 1
- /* ne10 */ ne10,
- /* ne11 */ 1, // ne11,
- /* ne12 */ 1, // ne12,
- /* ne13 */ 1, // ne13,
- /* nb10 */ nb10,
- /* nb11 */ nb11,
- /* nb12 */ nb12,
-        /* nb13 */ nb12, // ne12 == 1
- /* ne0 */ ne0,
- /* ne1 */ 1, // ne1,
- /* nb1 */ nb1,
- /* r2 */ 1,
- /* r3 */ 1,
- shared_values,
- tgpig,
- tiitg,
- tiisg,
- sgitg);
-}
-
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
-
-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
-#endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
-
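-// 2D max pooling: one thread per output element; gid is decomposed into
-// (channel, output row, output column) and the maximum is taken over the
-// kernel window clamped to the input bounds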
-kernel void kernel_pool_2d_max_f32(
- device const float * src0,
- device float * dst,
- constant int32_t & k0,
- constant int32_t & k1,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int64_t & IH,
- constant int64_t & IW,
- constant int64_t & OH,
- constant int64_t & OW,
- constant int64_t & parallel_elements,
- uint gid[[thread_position_in_grid]]) {
-
- if (gid >= parallel_elements) {
- return;
- }
-
- const int idx = gid;
- const int I_HW = IH * IW;
- const int O_HW = OH * OW;
- const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / OW;
- const int cur_ow = idx % O_HW % OW;
-
- device const float * i_ptr = src0 + nc * I_HW;
- device float * o_ptr = dst + nc * O_HW;
-
- const int start_h = cur_oh * s1 - p1;
- const int bh = MAX(0, start_h);
- const int eh = MIN(IH, start_h + k1);
- const int start_w = cur_ow * s0 - p0;
- const int bw = MAX(0, start_w);
- const int ew = MIN(IW, start_w + k0);
-
- float res = -INFINITY;
-
- for (int i = bh; i < eh; i += 1) {
- for (int j = bw; j < ew; j += 1) {
- res = MAX(res, i_ptr[i * IW + j]);
- }
- }
-
- o_ptr[cur_oh * OW + cur_ow] = res;
-}
-
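-// 2D average pooling: same indexing as the max variant; the accumulated sum is
-// scaled by the full kernel area (k0 * k1), not by the clamped window size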
-kernel void kernel_pool_2d_avg_f32(
- device const float * src0,
- device float * dst,
- constant int32_t & k0,
- constant int32_t & k1,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int64_t & IH,
- constant int64_t & IW,
- constant int64_t & OH,
- constant int64_t & OW,
- constant int64_t & parallel_elements,
- uint gid[[thread_position_in_grid]]) {
-
- if (gid >= parallel_elements) {
- return;
- }
-
- const int idx = gid;
- const int I_HW = IH * IW;
- const int O_HW = OH * OW;
- const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / OW;
- const int cur_ow = idx % O_HW % OW;
-
- device const float * i_ptr = src0 + nc * I_HW;
- device float * o_ptr = dst + nc * O_HW;
-
- const int start_h = cur_oh * s1 - p1;
- const int bh = MAX(0, start_h);
- const int eh = MIN(IH, start_h + k1);
- const int start_w = cur_ow * s0 - p0;
- const int bw = MAX(0, start_w);
- const int ew = MIN(IW, start_w + k0);
- // const float scale = 1. / ((eh - bh) * (ew - bw));
- const float scale = 1. / (k0 * k1);
-
- float res = 0;
-
- for (int i = bh; i < eh; i += 1) {
- for (int j = bw; j < ew; j += 1) {
- float cur = i_ptr[i * IW + j];
- res += cur * scale;
- }
- }
-
- o_ptr[cur_oh * OW + cur_ow] = res;
-}
+++ /dev/null
-#include "ggml-rpc.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include <cinttypes>
-#include <string>
-#include <vector>
-#include <memory>
-#include <mutex>
-#include <unordered_map>
-#include <unordered_set>
-#ifdef _WIN32
-# define WIN32_LEAN_AND_MEAN
-# ifndef NOMINMAX
-# define NOMINMAX
-# endif
-# include <windows.h>
-# include <winsock2.h>
-#else
-# include <arpa/inet.h>
-# include <sys/socket.h>
-# include <sys/types.h>
-# include <netinet/in.h>
-# include <netinet/tcp.h>
-# include <netdb.h>
-# include <unistd.h>
-#endif
-#include <cstring>
-
-#define UNUSED GGML_UNUSED
-
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#ifdef _WIN32
-typedef SOCKET sockfd_t;
-using ssize_t = __int64;
-#else
-typedef int sockfd_t;
-#endif
-
-// cross-platform socket
-struct socket_t {
- sockfd_t fd;
- socket_t(sockfd_t fd) : fd(fd) {}
- ~socket_t() {
- GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
-#ifdef _WIN32
- closesocket(this->fd);
-#else
- close(this->fd);
-#endif
- }
-};
-
-// all RPC structures must be packed
-#pragma pack(push, 1)
-// ggml_tensor is serialized into rpc_tensor
-struct rpc_tensor {
- uint64_t id;
- uint32_t type;
- uint64_t buffer;
- uint32_t ne[GGML_MAX_DIMS];
- uint32_t nb[GGML_MAX_DIMS];
- uint32_t op;
- int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
- int32_t flags;
- uint64_t src[GGML_MAX_SRC];
- uint64_t view_src;
- uint64_t view_offs;
- uint64_t data;
- char name[GGML_MAX_NAME];
-
- char padding[4];
-};
-
-static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
-
-// RPC commands
-enum rpc_cmd {
- RPC_CMD_ALLOC_BUFFER = 0,
- RPC_CMD_GET_ALIGNMENT,
- RPC_CMD_GET_MAX_SIZE,
- RPC_CMD_BUFFER_GET_BASE,
- RPC_CMD_FREE_BUFFER,
- RPC_CMD_BUFFER_CLEAR,
- RPC_CMD_SET_TENSOR,
- RPC_CMD_GET_TENSOR,
- RPC_CMD_COPY_TENSOR,
- RPC_CMD_GRAPH_COMPUTE,
- RPC_CMD_GET_DEVICE_MEMORY,
- RPC_CMD_COUNT,
-};
-
-struct rpc_msg_alloc_buffer_req {
- uint64_t size;
-};
-
-struct rpc_msg_alloc_buffer_rsp {
- uint64_t remote_ptr;
- uint64_t remote_size;
-};
-
-struct rpc_msg_get_alignment_rsp {
- uint64_t alignment;
-};
-
-struct rpc_msg_get_max_size_rsp {
- uint64_t max_size;
-};
-
-struct rpc_msg_buffer_get_base_req {
- uint64_t remote_ptr;
-};
-
-struct rpc_msg_buffer_get_base_rsp {
- uint64_t base_ptr;
-};
-
-struct rpc_msg_free_buffer_req {
- uint64_t remote_ptr;
-};
-
-struct rpc_msg_buffer_clear_req {
- uint64_t remote_ptr;
- uint8_t value;
-};
-
-struct rpc_msg_get_tensor_req {
- rpc_tensor tensor;
- uint64_t offset;
- uint64_t size;
-};
-
-struct rpc_msg_copy_tensor_req {
- rpc_tensor src;
- rpc_tensor dst;
-};
-
-struct rpc_msg_copy_tensor_rsp {
- uint8_t result;
-};
-
-struct rpc_msg_graph_compute_rsp {
- uint8_t result;
-};
-
-struct rpc_msg_get_device_memory_rsp {
- uint64_t free_mem;
- uint64_t total_mem;
-};
-#pragma pack(pop)
-
-// RPC data structures
-
-static ggml_guid_t ggml_backend_rpc_guid() {
- static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
- return &guid;
-}
-
-struct ggml_backend_rpc_buffer_type_context {
- std::string endpoint;
- std::string name;
- size_t alignment;
- size_t max_size;
-};
-
-struct ggml_backend_rpc_context {
- std::string endpoint;
- std::string name;
-};
-
-struct ggml_backend_rpc_buffer_context {
- std::shared_ptr<socket_t> sock;
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
- uint64_t remote_ptr;
-};
-
-// RPC helper functions
-
-static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
-#ifdef _WIN32
- if (fd == INVALID_SOCKET) {
- return nullptr;
- }
-#else
- if (fd < 0) {
- return nullptr;
- }
-#endif
- return std::make_shared<socket_t>(fd);
-}
-
-static bool set_no_delay(sockfd_t sockfd) {
- int flag = 1;
- // set TCP_NODELAY to disable Nagle's algorithm
- int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
- return ret == 0;
-}
-
-static bool set_reuse_addr(sockfd_t sockfd) {
- int flag = 1;
- int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
- return ret == 0;
-}
-
-static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
- struct sockaddr_in addr;
- auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
- auto sock_ptr = make_socket(sockfd);
- if (sock_ptr == nullptr) {
- return nullptr;
- }
- if (!set_no_delay(sockfd)) {
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
- return nullptr;
- }
- addr.sin_family = AF_INET;
- addr.sin_port = htons(port);
- struct hostent * server = gethostbyname(host);
- if (server == NULL) {
- fprintf(stderr, "Cannot resolve host '%s'\n", host);
- return nullptr;
- }
- memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
- if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
- return nullptr;
- }
- return sock_ptr;
-}
-
-static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
- auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
- auto client_socket = make_socket(client_socket_fd);
- if (client_socket == nullptr) {
- return nullptr;
- }
- if (!set_no_delay(client_socket_fd)) {
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
- return nullptr;
- }
- return client_socket;
-}
-
-static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
- auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
- auto sock = make_socket(sockfd);
- if (sock == nullptr) {
- return nullptr;
- }
- if (!set_reuse_addr(sockfd)) {
- fprintf(stderr, "Failed to set SO_REUSEADDR\n");
- return nullptr;
- }
- if (inet_addr(host) == INADDR_NONE) {
- fprintf(stderr, "Invalid host address: %s\n", host);
- return nullptr;
- }
- struct sockaddr_in serv_addr;
- serv_addr.sin_family = AF_INET;
- serv_addr.sin_addr.s_addr = inet_addr(host);
- serv_addr.sin_port = htons(port);
-
- if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
- return nullptr;
- }
- if (listen(sockfd, 1) < 0) {
- return nullptr;
- }
- return sock;
-}
-
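-// send/receive exactly `size` bytes, looping because a single send()/recv()
-// call may transfer only part of the buffer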
-static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
- size_t bytes_sent = 0;
- while (bytes_sent < size) {
- ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
- if (n < 0) {
- return false;
- }
- bytes_sent += n;
- }
- return true;
-}
-
-static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
- size_t bytes_recv = 0;
- while (bytes_recv < size) {
- ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
- if (n <= 0) {
- return false;
- }
- bytes_recv += n;
- }
- return true;
-}
-
-static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
- if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
- return false;
- }
- return send_data(sockfd, msg, msg_size);
-}
-
-static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
- uint64_t size;
- if (!recv_data(sockfd, &size, sizeof(size))) {
- return false;
- }
- if (size != msg_size) {
- return false;
- }
- return recv_data(sockfd, msg, msg_size);
-}
-
-static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
- uint64_t size;
- if (!recv_data(sockfd, &size, sizeof(size))) {
- return false;
- }
- try {
- input.resize(size);
- } catch (const std::bad_alloc & e) {
- fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
- return false;
- }
- return recv_data(sockfd, input.data(), size);
-}
-
-static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
- size_t pos = endpoint.find(':');
- if (pos == std::string::npos) {
- return false;
- }
- host = endpoint.substr(0, pos);
- port = std::stoi(endpoint.substr(pos + 1));
- return true;
-}
-
-// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
-// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
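-// e.g. RPC_CMD_GET_ALIGNMENT has an empty request and an 8-byte response:
-//   request : | 0x01 | 0 |
-//   response: | 8 | alignment (8 bytes) |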
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
- uint8_t cmd_byte = cmd;
- if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
- return false;
- }
- if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
- return false;
- }
- if (!send_data(sock->fd, input, input_size)) {
- return false;
- }
- // TODO: currently the output_size is always known, do we need support for commands with variable output size?
- // even if we do, we can skip sending output_size from the server for commands with known output size
- uint64_t out_size;
- if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
- return false;
- }
- if (out_size != output_size) {
- return false;
- }
- if (!recv_data(sock->fd, output, output_size)) {
- return false;
- }
- return true;
-}
-
-// RPC client-side implementation
-
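-// sockets are cached per endpoint via weak_ptr, so a connection is reused while
-// it is alive and re-established after all users have released it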
-static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
- static bool initialized = false;
-
- auto it = sockets.find(endpoint);
- if (it != sockets.end()) {
- if (auto sock = it->second.lock()) {
- return sock;
- }
- }
- std::string host;
- int port;
- if (!parse_endpoint(endpoint, host, port)) {
- return nullptr;
- }
-#ifdef _WIN32
- if (!initialized) {
- WSADATA wsaData;
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
- if (res != 0) {
- return nullptr;
- }
- initialized = true;
- }
-#else
- UNUSED(initialized);
-#endif
- auto sock = socket_connect(host.c_str(), port);
- if (sock == nullptr) {
- return nullptr;
- }
- GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
- sockets[endpoint] = sock;
- return sock;
-}
-
-static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_free_buffer_req request = {ctx->remote_ptr};
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
- GGML_ASSERT(status);
- delete ctx;
-}
-
-static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
- return ctx->base_cache[buffer];
- }
- rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
- rpc_msg_buffer_get_base_rsp response;
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
- GGML_ASSERT(status);
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
- ctx->base_cache[buffer] = base_ptr;
- return base_ptr;
-}
-
-static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
- rpc_tensor result;
- result.id = reinterpret_cast<uint64_t>(tensor);
- result.type = tensor->type;
- if (tensor->buffer) {
- ggml_backend_buffer_t buffer = tensor->buffer;
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- result.buffer = ctx->remote_ptr;
- } else {
- result.buffer = 0;
- }
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
- result.ne[i] = tensor->ne[i];
- result.nb[i] = tensor->nb[i];
- }
- result.op = tensor->op;
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
- result.op_params[i] = tensor->op_params[i];
- }
- result.flags = tensor->flags;
- for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
- result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
- }
- result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
- result.view_offs = tensor->view_offs;
- result.data = reinterpret_cast<uint64_t>(tensor->data);
- snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
- return result;
-}
-
-static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
- UNUSED(buffer);
- if (ggml_is_quantized(tensor->type)) {
- // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
- GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
- }
-}
-
-static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
- size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
- std::vector<uint8_t> input(input_size, 0);
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
- memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
- memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
- memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
- GGML_ASSERT(status);
-}
-
-static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_get_tensor_req request;
- request.tensor = serialize_tensor(tensor);
- request.offset = offset;
- request.size = size;
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
- GGML_ASSERT(status);
-}
-
-static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
- // check if src and dst are on the same server
- ggml_backend_buffer_t src_buffer = src->buffer;
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
- ggml_backend_buffer_t dst_buffer = dst->buffer;
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
- if (src_ctx->sock != dst_ctx->sock) {
- return false;
- }
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_copy_tensor_req request;
- request.src = serialize_tensor(src);
- request.dst = serialize_tensor(dst);
- rpc_msg_copy_tensor_rsp response;
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
- GGML_ASSERT(status);
- return response.result;
-}
-
-static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
- GGML_ASSERT(status);
-}
-
-static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
- /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
- /* .get_base = */ ggml_backend_rpc_buffer_get_base,
- /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_rpc_buffer_clear,
- /* .reset = */ NULL,
-};
-
-static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- return buft_ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- rpc_msg_alloc_buffer_req request = {size};
- rpc_msg_alloc_buffer_rsp response;
- auto sock = get_socket(buft_ctx->endpoint);
- bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
- GGML_ASSERT(status);
- if (response.remote_ptr != 0) {
- ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
- ggml_backend_rpc_buffer_interface,
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
- response.remote_size);
- return buffer;
- } else {
- return nullptr;
- }
-}
-
-static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
- rpc_msg_get_alignment_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
- GGML_ASSERT(status);
- return response.alignment;
-}
-
-static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- return buft_ctx->alignment;
-}
-
-static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
- rpc_msg_get_max_size_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
- GGML_ASSERT(status);
- return response.max_size;
-}
-
-static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- return buft_ctx->max_size;
-}
-
-static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- UNUSED(buft);
- return ggml_nbytes(tensor);
-}
-
-static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
- /* .get_name = */ ggml_backend_rpc_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_rpc_get_max_size,
- /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
- /* .is_host = */ NULL,
-};
-
-static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-
- return rpc_ctx->name.c_str();
-}
-
-static void ggml_backend_rpc_free(ggml_backend_t backend) {
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- delete rpc_ctx;
- delete backend;
-}
-
-static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
- UNUSED(backend);
- // this is no-op because we don't have any async operations
-}
-
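-// recursively collect a tensor and all of its sources into the flat tensor list,
-// visiting each tensor only once (sources are appended before their consumers)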
-static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
- if (tensor == nullptr) {
- return;
- }
- if (visited.find(tensor) != visited.end()) {
- return;
- }
- visited.insert(tensor);
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- add_tensor(tensor->src[i], tensors, visited);
- }
- add_tensor(tensor->view_src, tensors, visited);
- tensors.push_back(serialize_tensor(tensor));
-}
-
-static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
- uint32_t n_nodes = cgraph->n_nodes;
- std::vector<rpc_tensor> tensors;
- std::unordered_set<ggml_tensor*> visited;
- for (uint32_t i = 0; i < n_nodes; i++) {
- add_tensor(cgraph->nodes[i], tensors, visited);
- }
- // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
- uint32_t n_tensors = tensors.size();
- int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
- output.resize(output_size, 0);
- memcpy(output.data(), &n_nodes, sizeof(n_nodes));
- for (uint32_t i = 0; i < n_nodes; i++) {
- memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
- }
- uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
- *out_ntensors = n_tensors;
- rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
- memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
-}
-
-static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- std::vector<uint8_t> input;
- serialize_graph(cgraph, input);
- rpc_msg_graph_compute_rsp response;
- auto sock = get_socket(rpc_ctx->endpoint);
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
- GGML_ASSERT(status);
- return (enum ggml_status)response.result;
-}
-
-static ggml_backend_i ggml_backend_rpc_interface = {
- /* .get_name = */ ggml_backend_rpc_name,
- /* .free = */ ggml_backend_rpc_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ ggml_backend_rpc_synchronize,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_rpc_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- // NOTE: buffer types are allocated and never freed; this is by design
- static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
- auto it = buft_map.find(endpoint);
- if (it != buft_map.end()) {
- return it->second;
- }
- auto sock = get_socket(endpoint);
- if (sock == nullptr) {
- fprintf(stderr, "Failed to connect to %s\n", endpoint);
- return nullptr;
- }
- size_t alignment = get_alignment(sock);
- size_t max_size = get_max_size(sock);
- ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
- /* .alignment = */ alignment,
- /* .max_size = */ max_size
- };
-
- ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
- /* .iface = */ ggml_backend_rpc_buffer_type_interface,
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
- /* .context = */ buft_ctx
- };
- buft_map[endpoint] = buft;
- return buft;
-}
-
-ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
- ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
- };
-
- ggml_backend_t backend = new ggml_backend {
- /* .guid = */ ggml_backend_rpc_guid(),
- /* .interface = */ ggml_backend_rpc_interface,
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
- /* .context = */ ctx
- };
- return backend;
-}
-
-GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
-}
-
-static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
- rpc_msg_get_device_memory_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
- GGML_ASSERT(status);
- *free = response.free_mem;
- *total = response.total_mem;
-}
-
-GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
- auto sock = get_socket(endpoint);
- if (sock == nullptr) {
- *free = 0;
- *total = 0;
- return;
- }
- get_device_memory(sock, free, total);
-}
-
-// RPC server-side implementation
-
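-// the server owns the buffers it allocates on behalf of clients and frees any
-// that are still alive when the connection is closed (see ~rpc_server)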
-class rpc_server {
-public:
- rpc_server(ggml_backend_t backend) : backend(backend) {}
- ~rpc_server();
-
- void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
- void get_alignment(rpc_msg_get_alignment_rsp & response);
- void get_max_size(rpc_msg_get_max_size_rsp & response);
- bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
- bool free_buffer(const rpc_msg_free_buffer_req & request);
- bool buffer_clear(const rpc_msg_buffer_clear_req & request);
- bool set_tensor(const std::vector<uint8_t> & input);
- bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
- bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
- bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
-
-private:
- ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
- ggml_tensor * create_node(uint64_t id,
- struct ggml_context * ctx,
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
-
-
- ggml_backend_t backend;
- std::unordered_set<ggml_backend_buffer_t> buffers;
-};
-
-void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
- response.remote_ptr = 0;
- response.remote_size = 0;
- if (buffer != nullptr) {
- response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
- response.remote_size = buffer->size;
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
- buffers.insert(buffer);
- } else {
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
- }
-}
-
-void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
- size_t alignment = ggml_backend_buft_get_alignment(buft);
-    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
- response.alignment = alignment;
-}
-
-void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
- size_t max_size = ggml_backend_buft_get_max_size(buft);
-    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
- response.max_size = max_size;
-}
-
-bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
- if (buffers.find(buffer) == buffers.end()) {
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
- return false;
- }
- void * base = ggml_backend_buffer_get_base(buffer);
- response.base_ptr = reinterpret_cast<uint64_t>(base);
- return true;
-}
-
-bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
- if (buffers.find(buffer) == buffers.end()) {
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
- return false;
- }
- ggml_backend_buffer_free(buffer);
- buffers.erase(buffer);
- return true;
-}
-
-bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
- if (buffers.find(buffer) == buffers.end()) {
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
- return false;
- }
- ggml_backend_buffer_clear(buffer, request.value);
- return true;
-}
-
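-// reconstruct a ggml_tensor from its serialized form; the referenced buffer must
-// be one owned by this server and the tensor data must fit inside it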
-ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
- ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
- tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
- result->nb[i] = tensor->nb[i];
- }
- result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
- if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
- result->buffer = nullptr;
- }
-
- if (result->buffer) {
- // require that the tensor data does not go beyond the buffer end
- uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
- uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
- uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
- GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
- GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
- }
-
- result->op = (ggml_op) tensor->op;
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
- result->op_params[i] = tensor->op_params[i];
- }
- result->flags = tensor->flags;
- result->data = reinterpret_cast<void *>(tensor->data);
- ggml_set_name(result, tensor->name);
- return result;
-}
-
-
-bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
- // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
- if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
- return false;
- }
- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
- uint64_t offset;
- memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
- const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
-
- struct ggml_init_params params {
- /*.mem_size =*/ ggml_tensor_overhead(),
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
- struct ggml_context * ctx = ggml_init(params);
- ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
- if (tensor == nullptr) {
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
- ggml_free(ctx);
- return false;
- }
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
-
- // sanitize tensor->data
- {
- const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
- const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
-
- if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
- }
- }
-
- const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
- ggml_backend_tensor_set(tensor, data, offset, size);
- ggml_free(ctx);
- return true;
-}
-
-bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
- struct ggml_init_params params {
- /*.mem_size =*/ ggml_tensor_overhead(),
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
- struct ggml_context * ctx = ggml_init(params);
- ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
- if (tensor == nullptr) {
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
- ggml_free(ctx);
- return false;
- }
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
-
- // sanitize tensor->data
- {
- const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
- const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
-
- if (request.tensor.data + request.offset < p0 ||
- request.tensor.data + request.offset >= p1 ||
- request.size > (p1 - request.tensor.data - request.offset)) {
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
- }
- }
-
- response.resize(request.size, 0);
- ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
- ggml_free(ctx);
- return true;
-}
-
-bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
- struct ggml_init_params params {
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
- struct ggml_context * ctx = ggml_init(params);
- ggml_tensor * src = deserialize_tensor(ctx, &request.src);
- ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
- if (src == nullptr || dst == nullptr) {
- GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
- ggml_free(ctx);
- return false;
- }
- GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
- response.result = ggml_backend_buffer_copy_tensor(src, dst);
- ggml_free(ctx);
- return true;
-}
-
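-// rebuild a graph node from its id, memoizing results in tensor_map so that
-// shared sources are deserialized only once; id == 0 denotes a null tensor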
-ggml_tensor * rpc_server::create_node(uint64_t id,
- struct ggml_context * ctx,
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
- if (id == 0) {
- return nullptr;
- }
- if (tensor_map.find(id) != tensor_map.end()) {
- return tensor_map[id];
- }
- const rpc_tensor * tensor = tensor_ptrs.at(id);
- struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
- if (result == nullptr) {
- return nullptr;
- }
- tensor_map[id] = result;
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
- }
- result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
- result->view_offs = tensor->view_offs;
- return result;
-}
-
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
- // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
- if (input.size() < sizeof(uint32_t)) {
- return false;
- }
- uint32_t n_nodes;
- memcpy(&n_nodes, input.data(), sizeof(n_nodes));
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
- return false;
- }
- const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
- uint32_t n_tensors;
- memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
- return false;
- }
- const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
- GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
-
- size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_size,
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
- struct ggml_context * ctx = ggml_init(params);
- struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
- graph->n_nodes = n_nodes;
- std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
- for (uint32_t i = 0; i < n_tensors; i++) {
- tensor_ptrs[tensors[i].id] = &tensors[i];
- }
- std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
- for (uint32_t i = 0; i < n_nodes; i++) {
- int64_t id;
- memcpy(&id, &nodes[i], sizeof(id));
- graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
- }
- ggml_status status = ggml_backend_graph_compute(backend, graph);
- response.result = status;
- ggml_free(ctx);
- return true;
-}
-
-rpc_server::~rpc_server() {
- for (auto buffer : buffers) {
- ggml_backend_buffer_free(buffer);
- }
-}
-
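-// serve a single client connection: read one command per iteration and dispatch
-// it to the rpc_server until the client disconnects or sends a malformed message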
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
- rpc_server server(backend);
- while (true) {
- uint8_t cmd;
- if (!recv_data(sockfd, &cmd, 1)) {
- break;
- }
- if (cmd >= RPC_CMD_COUNT) {
- // fail fast if the command is invalid
- fprintf(stderr, "Unknown command: %d\n", cmd);
- break;
- }
- switch (cmd) {
- case RPC_CMD_ALLOC_BUFFER: {
- rpc_msg_alloc_buffer_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- rpc_msg_alloc_buffer_rsp response;
- server.alloc_buffer(request, response);
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_GET_ALIGNMENT: {
- if (!recv_msg(sockfd, nullptr, 0)) {
- return;
- }
- rpc_msg_get_alignment_rsp response;
- server.get_alignment(response);
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_GET_MAX_SIZE: {
- if (!recv_msg(sockfd, nullptr, 0)) {
- return;
- }
- rpc_msg_get_max_size_rsp response;
- server.get_max_size(response);
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_BUFFER_GET_BASE: {
- rpc_msg_buffer_get_base_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- rpc_msg_buffer_get_base_rsp response;
- if (!server.buffer_get_base(request, response)) {
- return;
- }
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_FREE_BUFFER: {
- rpc_msg_free_buffer_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- if (!server.free_buffer(request)) {
- return;
- }
- if (!send_msg(sockfd, nullptr, 0)) {
- return;
- }
- break;
- }
- case RPC_CMD_BUFFER_CLEAR: {
- rpc_msg_buffer_clear_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- if (!server.buffer_clear(request)) {
- return;
- }
- if (!send_msg(sockfd, nullptr, 0)) {
- return;
- }
- break;
- }
- case RPC_CMD_SET_TENSOR: {
- std::vector<uint8_t> input;
- if (!recv_msg(sockfd, input)) {
- return;
- }
- if (!server.set_tensor(input)) {
- return;
- }
- if (!send_msg(sockfd, nullptr, 0)) {
- return;
- }
- break;
- }
- case RPC_CMD_GET_TENSOR: {
- rpc_msg_get_tensor_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- std::vector<uint8_t> response;
- if (!server.get_tensor(request, response)) {
- return;
- }
- if (!send_msg(sockfd, response.data(), response.size())) {
- return;
- }
- break;
- }
- case RPC_CMD_COPY_TENSOR: {
- rpc_msg_copy_tensor_req request;
- if (!recv_msg(sockfd, &request, sizeof(request))) {
- return;
- }
- rpc_msg_copy_tensor_rsp response;
- if (!server.copy_tensor(request, response)) {
- return;
- }
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_GRAPH_COMPUTE: {
- std::vector<uint8_t> input;
- if (!recv_msg(sockfd, input)) {
- return;
- }
- rpc_msg_graph_compute_rsp response;
- if (!server.graph_compute(input, response)) {
- return;
- }
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- case RPC_CMD_GET_DEVICE_MEMORY: {
- if (!recv_msg(sockfd, nullptr, 0)) {
- return;
- }
- rpc_msg_get_device_memory_rsp response;
- response.free_mem = free_mem;
- response.total_mem = total_mem;
- if (!send_msg(sockfd, &response, sizeof(response))) {
- return;
- }
- break;
- }
- default: {
- fprintf(stderr, "Unknown command: %d\n", cmd);
- return;
- }
- }
- }
-}
-
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
- std::string host;
- int port;
- if (!parse_endpoint(endpoint, host, port)) {
- return;
- }
-#ifdef _WIN32
- {
- WSADATA wsaData;
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
- if (res != 0) {
- fprintf(stderr, "WSAStartup failed: %d\n", res);
- return;
- }
- }
-#endif
- auto server_socket = create_server_socket(host.c_str(), port);
- if (server_socket == nullptr) {
- fprintf(stderr, "Failed to create server socket\n");
- return;
- }
- while (true) {
- auto client_socket = socket_accept(server_socket->fd);
- if (client_socket == nullptr) {
- fprintf(stderr, "Failed to accept client connection\n");
- return;
- }
- printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
- fflush(stdout);
- rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
- printf("Client connection closed\n");
- fflush(stdout);
- }
-#ifdef _WIN32
- WSACleanup();
-#endif
-}
-
-// device interface
-
-struct ggml_backend_rpc_device_context {
- std::string endpoint;
- std::string name;
-};
-
-static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
- return ctx->name.c_str();
-}
-
-static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
- return ctx->name.c_str();
-}
-
-static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
- ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
-
- UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
- // TODO: obtain value from the server
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-
- UNUSED(dev);
-}
-
-static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_rpc_device_get_name(dev);
- props->description = ggml_backend_rpc_device_get_description(dev);
- props->type = ggml_backend_rpc_device_get_type(dev);
- ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
- return ggml_backend_rpc_init(ctx->endpoint.c_str());
-
- UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
-
- UNUSED(dev);
-}
-
-static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
- UNUSED(dev);
- UNUSED(op);
- //TODO: call the remote backend and cache the results
- return true;
-}
-
-static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
- return false;
- }
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
- return buft_ctx->endpoint == dev_ctx->endpoint;
-}
-
-static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
- /* .get_name = */ ggml_backend_rpc_device_get_name,
- /* .get_description = */ ggml_backend_rpc_device_get_description,
- /* .get_memory = */ ggml_backend_rpc_device_get_memory,
- /* .get_type = */ ggml_backend_rpc_device_get_type,
- /* .get_props = */ ggml_backend_rpc_device_get_props,
- /* .init_backend = */ ggml_backend_rpc_device_init,
- /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_rpc_device_supports_op,
- /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
- /* .offload_op = */ NULL,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
- return "RPC";
-
- UNUSED(reg);
-}
-
-static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
- return 0;
-
- UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
-
- UNUSED(reg);
- UNUSED(index);
-}
-
-static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
- return (void *)ggml_backend_rpc_add_device;
- }
- return NULL;
-
- UNUSED(reg);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
- /* .get_name = */ ggml_backend_rpc_reg_get_name,
- /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
- /* .get_device = */ ggml_backend_rpc_reg_get_device,
- /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_rpc_reg(void) {
- static struct ggml_backend_reg ggml_backend_rpc_reg = {
- /* .iface = */ ggml_backend_rpc_reg_i,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_rpc_reg;
-}
-
-ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
- static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
-
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- if (dev_map.find(endpoint) != dev_map.end()) {
- return dev_map[endpoint];
- }
-
- ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
- };
-
- ggml_backend_dev_t dev = new ggml_backend_device {
- /* .iface = */ ggml_backend_rpc_device_i,
- /* .reg = */ ggml_backend_rpc_reg(),
- /* .context = */ ctx,
- };
-
- dev_map[endpoint] = dev;
-
- return dev;
-}
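-
-// Illustrative usage sketch (not part of the original file; the endpoint value
-// and error handling are assumptions): a client registers a remote server once
-// via ggml_backend_rpc_add_device() and then creates a backend from the
-// returned device through the generic ggml-backend device API.
-//
-//   ggml_backend_dev_t dev = ggml_backend_rpc_add_device("192.168.1.10:50052");
-//   ggml_backend_t backend = ggml_backend_dev_init(dev, /* params = */ nullptr);
-//   if (backend == nullptr) { /* handle connection failure */ }
-//   // ... run graphs on `backend`, then release it with ggml_backend_free(backend)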
+++ /dev/null
-//
-// MIT license
-// Copyright (C) 2024 Intel Corporation
-// SPDX-License-Identifier: MIT
-//
-
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-
-#include <algorithm>
-#include <array>
-#include <assert.h>
-#include <atomic>
-#include <cinttypes>
-#include <climits>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <cstdlib>
-#include <float.h>
-#include <fstream>
-#include <iostream>
-#include <limits>
-#include <map>
-#include <memory>
-#include <mutex>
-#include <regex>
-#include <sstream>
-#include <stdio.h>
-#include <vector>
-
-#include <sycl/sycl.hpp>
-#include <sycl/half_type.hpp>
-
-#include "ggml-sycl.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-sycl/backend.hpp"
-#include "ggml-sycl/presets.hpp"
-#include "ggml-sycl/gemm.hpp"
-
-static bool g_sycl_loaded = false;
-
-static ggml_sycl_device_info ggml_sycl_init() {
- ggml_sycl_device_info info = {};
-
- info.device_count = dpct::dev_mgr::instance().device_count();
- if (info.device_count == 0) {
- fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__);
- return info;
- }
-
- GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
-
- int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
- fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
-#else
- fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
- fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
- fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
- fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count);
-
- for (int i = 0; i < info.device_count; ++i) {
- info.devices[i].vmm = 0;
- dpct::device_info prop;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(i))));
-
- info.default_tensor_split[i] = total_vram;
- total_vram += prop.get_global_mem_size();
-
- info.devices[i].cc =
- 100 * prop.get_major_version() + 10 * prop.get_minor_version();
-
- info.max_work_group_sizes[i] = prop.get_max_work_group_size();
- }
-
- for (int id = 0; id < info.device_count; ++id) {
- info.default_tensor_split[id] /= total_vram;
- }
- return info;
-}
-
-const ggml_sycl_device_info & ggml_sycl_info() {
- static ggml_sycl_device_info info = ggml_sycl_init();
- return info;
-}
-
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
-
- dpct::device_info prop;
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::get_device_info(prop, device)));
-
- std::string version;
- version += std::to_string(prop.get_major_version());
- version += ".";
- version += std::to_string(prop.get_minor_version());
-
- device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
- std::string name = std::string(prop.get_name());
- name = std::regex_replace(name, std::regex("\\(R\\)"), "");
- name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
-
- auto global_mem_size = prop.get_global_mem_size()/1000000;
-
- fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
- name.c_str(), version.c_str(), prop.get_max_compute_units(),
- prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
- global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
-}
-
-void ggml_backend_sycl_print_sycl_devices() {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
- int device_count = dpct::dev_mgr::instance().device_count();
- std::map<std::string, size_t> DeviceNums;
- fprintf(stderr, "found %d SYCL devices:\n", device_count);
- fprintf(stderr, "| | | | |Max | |Max |Global | |\n");
- fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n");
- fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n");
- fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
- for (int id = 0; id < device_count; ++id) {
- sycl::device device = dpct::dev_mgr::instance().get_device(id);
- sycl::backend backend = device.get_backend();
- std::string backend_type = get_device_backend_and_type(device);
- int type_id = DeviceNums[backend_type]++;
- std::stringstream device_type;
- device_type << "[" << backend_type << ":" << std::to_string(type_id) << "]";
- print_device_detail(id, device, device_type.str());
- }
-}
-
-static inline int get_sycl_env(const char *env_name, int default_val) {
- char *user_device_string = getenv(env_name);
- int user_number = default_val;
-
- unsigned n;
- if (user_device_string != NULL &&
- sscanf(user_device_string, " %u", &n) == 1) {
- user_number = (int)n;
- } else {
- user_number = default_val;
- }
- return user_number;
-}
-
-static void ggml_check_sycl() try {
- static bool initialized = false;
-
- if (!initialized) {
- fprintf(stderr, "[SYCL] call ggml_check_sycl\n");
- g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-
- fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
-#if defined(GGML_SYCL_F16)
- fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__);
-#else
- fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
-#endif
-
-/* DO NOT REMOVE; keep it for the next XMX optimization pass.
-#if defined(SYCL_USE_XMX)
- fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
- fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-*/
-
- if (CHECK_TRY_ERROR(g_all_sycl_device_count =
- dpct::dev_mgr::instance().device_count()) != 0) {
- initialized = true;
- g_sycl_loaded = false;
- return;
- }
- GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
- ggml_backend_sycl_print_sycl_devices();
- initialized = true;
- g_sycl_loaded = true;
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-/*
-device_index: device index from 0 to n-1 (contiguous numbering).
- It is used to select/set the device in the SYCL backend's internal data structures.
-*/
-inline void check_allow_gpu_index(const int device_index) {
- if (device_index >= ggml_sycl_info().device_count) {
- char error_buf[256];
- snprintf(
- error_buf,
- sizeof(error_buf),
- "%s error: device_index:%d is out of range: [0-%d]",
- __func__,
- device_index,
- ggml_sycl_info().device_count - 1);
- fprintf(stderr, "%s\n", error_buf);
- assert(false);
- }
-}
-
-GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
- for (int i = 0; i < max_len; i++) id_list[i] = -1;
-
- for (int i = 0; i < ggml_sycl_info().device_count; i++) {
- if (i >= max_len) break;
- id_list[i] = i;
- }
- return;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-// sycl buffer
-
-struct ggml_backend_sycl_buffer_context {
- int device;
- void * dev_ptr = nullptr;
- queue_ptr stream;
- std::string name;
-
- ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
- device(device), dev_ptr(dev_ptr), stream(stream) {
- check_allow_gpu_index(device);
- name = (GGML_SYCL_NAME + std::to_string(device));
- }
-
-
- ~ggml_backend_sycl_buffer_context() {
- if (dev_ptr != nullptr) {
- ggml_sycl_set_device(device);
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
- }
- }
-};
-
-static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);
-
-static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
- return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;
-}
-
-static void
-ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
- ggml_sycl_set_device(ctx->device);
-
- delete ctx;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
- return ctx->dev_ptr;
-}
-
-static void
-ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
- ggml_tensor *tensor) try {
- ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
-
- if (tensor->view_src != NULL && tensor->view_offs == 0) {
- assert(tensor->view_src->buffer->buft == buffer->buft);
- tensor->backend = tensor->view_src->backend;
- tensor->extra = tensor->view_src->extra;
- return;
- }
-
-
- if (ggml_is_quantized(tensor->type)) {
- // initialize padding to 0 to avoid possible NaN values
- size_t original_size = ggml_nbytes(tensor);
- size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
- if (padded_size > original_size && tensor->view_src == nullptr) {
- SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
- (char *)tensor->data + original_size, 0,
- padded_size - original_size).wait()));
- }
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
- ggml_tensor *tensor,
- const void *data, size_t offset,
- size_t size) try {
-
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
- ggml_sycl_set_device(ctx->device);
- auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
- SYCL_CHECK(
- CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
- char* host_buf = (char*)malloc(size);
- memcpy(host_buf, data, size);
- SYCL_CHECK(
- CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
- .wait()));
- free(host_buf);
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
- const ggml_tensor *tensor,
- void *data, size_t offset,
- size_t size) try {
-
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
- ggml_sycl_set_device(ctx->device);
- auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
-
- SYCL_CHECK(CHECK_TRY_ERROR(
- stream.memcpy(data, (const char *)tensor->data + offset, size)
- .wait()));
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
- const void *ptr_src, size_t size) {
- char *host_buf = (char *)malloc(size);
- q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
- q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
- free(host_buf);
-}
-
-static bool
-ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
- const ggml_tensor *src,
- ggml_tensor *dst) try {
- if (ggml_backend_buffer_is_sycl(src->buffer)) {
- ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
- ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
-
- ggml_sycl_set_device(src_ctx->device);
- /*
- DPCT1009:198: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
- ggml_sycl_set_device(dst_ctx->device);
- /*
- DPCT1009:199: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
- /*
- DPCT1009:200: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
-
- queue_ptr stream_dst = dst_ctx->stream;
- queue_ptr stream_src = src_ctx->stream;
- size_t size = ggml_nbytes(src);
-
- // TODO: dirty workaround for a known issue: device-to-device copies across GPUs fail.
- dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
-
-// TODO: known issue: device-to-device copies across GPUs fail; re-enable this path once fixed. DO NOT remove.
-#if 0
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
- (char *)dst->data, (const char *)src->data, size).wait()));
-
- /*
- DPCT1009:201: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
-#endif
- return true;
- }
- return false;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-
-static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
- uint8_t value) try {
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
- ggml_sycl_set_device(ctx->device);
- queue_ptr stream = ctx->stream;
- SYCL_CHECK(
- CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
-
- SYCL_CHECK(CHECK_TRY_ERROR((*stream)
- .memset(ctx->dev_ptr, value, buffer->size)
- .wait()));
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
- /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
- /* .get_base = */ ggml_backend_sycl_buffer_get_base,
- /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_sycl_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// sycl buffer type
-struct ggml_backend_sycl_buffer_type_context {
- int device;
- std::string name;
-
- // each buffer type has its own stream
- queue_ptr stream = nullptr;
-};
-
-static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-
- return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t
-ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
- size_t size) try {
- ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
- ggml_sycl_set_device(buft_ctx->device);
- const queue_ptr stream = buft_ctx->stream;
- size = std::max(size, (size_t)1); // avoid zero-sized allocations, which may return a null pointer
-
- void * dev_ptr;
- SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
- size, *stream)));
- if (!dev_ptr) {
- fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, size);
- return nullptr;
- }
- ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
- return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return 128;
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
- return dpct::get_current_device().get_max_mem_alloc_size();
-
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- size_t size = ggml_nbytes(tensor);
- int64_t ne0 = tensor->ne[0];
-
- if (ggml_is_quantized(tensor->type)) {
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
- }
-
- return size;
-
- GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
- /* .get_name = */ ggml_backend_sycl_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size,
- /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size,
- /* .is_host = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
-
- auto dev_count = ggml_backend_sycl_get_device_count();
-
- if (device >= dev_count || device < 0) {
- printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
- device, dev_count - 1);
- GGML_ASSERT(device < dev_count);
- }
- static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
- static bool ggml_backend_sycl_buffer_type_initialized = false;
-
- if (!ggml_backend_sycl_buffer_type_initialized) {
- for (int i = 0; i < dev_count; i++) {
- auto & device_i = dpct::dev_mgr::instance().get_device(i);
- queue_ptr stream = &(device_i.default_queue());
- ggml_backend_sycl_buffer_types[i] = {
- /* .iface = */ ggml_backend_sycl_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
- /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
- };
- }
- ggml_backend_sycl_buffer_type_initialized = true;
- }
- return &ggml_backend_sycl_buffer_types[device];
-}
-
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
-
- int device = ctx->device;
- if (device >= ggml_sycl_info().device_count || device < 0) {
- printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
- device, ggml_sycl_info().device_count - 1);
- GGML_ASSERT(device < ggml_sycl_info().device_count);
- }
- static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
- static bool ggml_backend_sycl_buffer_type_initialized = false;
-
- if (!ggml_backend_sycl_buffer_type_initialized) {
- for (int i = 0; i < ggml_sycl_info().device_count; i++) {
- ggml_backend_sycl_buffer_types[i] = {
- /* .iface = */ ggml_backend_sycl_buffer_type_interface,
- /* .device = */ nullptr,
- /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
- };
- }
- ggml_backend_sycl_buffer_type_initialized = true;
- }
- return &ggml_backend_sycl_buffer_types[device];
-}
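-
-// Illustrative sketch (not part of the original file; assumes the generic
-// ggml-backend allocation helpers): allocating the tensors of a ggml context
-// in device memory through this buffer type.
-//
-//   ggml_backend_buffer_type_t buft = ggml_backend_sycl_buffer_type(0 /* device */);
-//   ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-//   // tensors created in `ctx` now live in SYCL device memory; release the
-//   // allocation with ggml_backend_buffer_free(buf) when done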
-
-// sycl split buffer
-
-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
- int64_t min_compute_capability = INT_MAX;
- int64_t max_compute_capability = INT_MIN;
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
- if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
- min_compute_capability = ggml_sycl_info().devices[i].cc;
- }
- if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
- max_compute_capability = ggml_sycl_info().devices[i].cc;
- }
- }
- }
-
- switch(type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- return max_compute_capability >= VER_GEN9 ? 128 : 64;
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- return 64;
- case GGML_TYPE_F16:
- case GGML_TYPE_F32:
- return 1;
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ2_S:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ4_NL:
- return max_compute_capability >= VER_GEN9 ? 128 : 64;
- case GGML_TYPE_IQ3_S:
- return max_compute_capability >= VER_GEN9 ? 128 : 64;
- case GGML_TYPE_Q6_K:
- return 64;
- default:
- GGML_ABORT("fatal error");
- }
-}
-
-static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
- const int64_t nrows = ggml_nrows(tensor);
- const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
-
- *row_low = id == 0 ? 0 : nrows*tensor_split[id];
- *row_low -= *row_low % rounding;
- if (id == ggml_sycl_info().device_count - 1) {
- *row_high = nrows;
- } else {
- *row_high = nrows*tensor_split[id + 1];
- *row_high -= *row_high % rounding;
- }
-}
-
-static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
- static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
- return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
-}
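-
-// Worked example (illustrative only, the numbers are assumptions): with
-// nrows = 1000, two devices, tensor_split = {0.0f, 0.6f} and rounding = 64,
-// get_row_split() gives device 0 rows [0, 576) (600 rounded down to a multiple
-// of 64) and device 1, being the last device, the remainder [576, 1000).
-// ggml_nbytes_split() then sizes each device's slice as nrows_split full rows.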
-
-struct ggml_backend_sycl_split_buffer_type_context {
- std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
-};
-
-struct ggml_backend_sycl_split_buffer_context {
- ~ggml_backend_sycl_split_buffer_context() try {
- for (ggml_tensor_extra_gpu * extra : tensor_extras) {
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
- if (extra->events[i][is] != nullptr) {
- /*
- DPCT1009:206: SYCL uses exceptions to report errors and
- does not use the error codes. The original code was
- commented out and a warning string was inserted. You
- need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::destroy_event(extra->events[i][is])));
- }
- }
- if (extra->data_device[i] != nullptr) {
- /*
- DPCT1009:207: SYCL uses exceptions to report errors and does
- not use the error codes. The original code was commented out
- and a warning string was inserted. You need to rewrite this
- code.
- */
- ggml_sycl_set_device(i);
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
- extra->data_device[i], *(streams[i]))));
- }
- }
- delete extra;
- }
- }
- catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
- }
-
- std::vector<ggml_tensor_extra_gpu *> tensor_extras;
- std::vector<queue_ptr> streams;
-};
-
-static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
- delete ctx;
-}
-
-static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
- // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
- return (void *)0x1000;
-
- GGML_UNUSED(buffer);
-}
-
-static void
-ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
- ggml_tensor *tensor) try {
- GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
-
- ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
-
- ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
-
- ctx->tensor_extras.push_back(extra);
- ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- // FIXME: do not crash if the device allocation fails
- // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
- ggml_sycl_set_device(i);
- const queue_ptr stream = ctx->streams[i];
- char * buf;
- /*
- DPCT1009:208: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
- size, *stream)));
- if (!buf) {
- char err_buf[1024];
- snprintf(err_buf, sizeof(err_buf), "%s: can't allocate %zu bytes of memory on device", __func__, size);
- throw std::runtime_error(err_buf);
- }
- // set padding to 0 to avoid possible NaN values
- if (size > original_size) {
- /*
- DPCT1009:209: SYCL uses exceptions to report errors and does not use
- the error codes. The original code was commented out and a warning
- string was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- (*stream)
- .memset(buf + original_size, 0, size - original_size)
- .wait()));
- }
-
- extra->data_device[i] = buf;
-
- for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
- /*
- DPCT1009:210: SYCL uses exceptions to report errors and does not use
- the error codes. The original code was commented out and a warning
- string was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(
- CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
- }
- }
- tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
- tensor->extra = extra;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void
-ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
- ggml_tensor *tensor, const void *data,
- size_t offset, size_t size) try {
- // split tensors must always be set in their entirety at once
- GGML_ASSERT(offset == 0);
- GGML_ASSERT(size == ggml_nbytes(tensor));
-
- ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
- const size_t nb1 = tensor->nb[1];
- ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- const size_t offset_split = row_low*nb1;
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- const char * buf_host = (const char *)data + offset_split;
- /*
- DPCT1009:211: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- ggml_sycl_set_device(i);
- const queue_ptr stream = ctx->streams[i];
- SYCL_CHECK(CHECK_TRY_ERROR(
- (*stream)
- .memcpy(extra->data_device[i], buf_host, original_size)
- .wait()));
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void
-ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
- const ggml_tensor *tensor, void *data,
- size_t offset, size_t size) try {
- // split tensors must always be read in their entirety at once
- GGML_ASSERT(offset == 0);
- GGML_ASSERT(size == ggml_nbytes(tensor));
-
- ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
- const int64_t ne0 = tensor->ne[0];
- const size_t nb1 = tensor->nb[1];
- ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- const size_t offset_split = row_low*nb1;
- size_t size = ggml_nbytes_split(tensor, nrows_split);
- const size_t original_size = size;
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
-
- char * buf_host = (char *)data + offset_split;
- /*
- DPCT1009:212: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- ggml_sycl_set_device(i);
- const queue_ptr stream = ctx->streams[i];
- SYCL_CHECK(CHECK_TRY_ERROR(
- (*stream)
- .memcpy(buf_host, extra->data_device[i], original_size)
- .wait()));
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- GGML_UNUSED(buffer);
- GGML_UNUSED(value);
-}
-
-static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
- /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer,
- /* .get_base = */ ggml_backend_sycl_split_buffer_get_base,
- /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor,
- /* .cpy_tensor = */ NULL,
- /* .clear = */ ggml_backend_sycl_split_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// sycl split buffer type
-
-static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return GGML_SYCL_NAME "_Split";
-
- GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
- return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
- // instead, we allocate them for each tensor separately in init_tensor
- // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
- // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
- ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
-
- return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return 128;
- GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
-
- size_t total_size = 0;
-
- const int64_t ne0 = tensor->ne[0];
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- int64_t row_low, row_high;
- get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
-
- int64_t nrows_split = row_high - row_low;
- if (nrows_split == 0) {
- continue;
- }
-
- total_size += ggml_nbytes_split(tensor, nrows_split);
-
- // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
- if (ne0 % MATRIX_ROW_PADDING != 0) {
- total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
- }
- }
-
- return total_size;
-}
-
-static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return false;
-
- GGML_UNUSED(buft);
-}
-
-static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
- /* .get_name = */ ggml_backend_sycl_split_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
- /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host,
-};
-
-ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
- ggml_check_sycl();
- // FIXME: this is not thread safe
- static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
-
- std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
-
- bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
- if (all_zero) {
- tensor_split_arr = ggml_sycl_info().default_tensor_split;
- } else {
- float split_sum = 0.0f;
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- tensor_split_arr[i] = split_sum;
- split_sum += tensor_split[i];
- }
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- tensor_split_arr[i] /= split_sum;
- }
- }
-
- auto it = buft_map.find(tensor_split_arr);
- if (it != buft_map.end()) {
- return &it->second;
- }
-
- struct ggml_backend_buffer_type buft {
- /* .iface = */ ggml_backend_sycl_split_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
- /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
- };
-
- auto result = buft_map.emplace(tensor_split_arr, buft);
- return &result.first->second;
-}
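-
-// Worked example (illustrative only, the numbers are assumptions): a caller
-// passing tensor_split = {3.0f, 1.0f} for two devices yields the cumulative
-// start fractions {0.0f, 0.75f}: device 0 covers the first 75% of the rows and
-// device 1 the remaining 25%, which is how get_row_split() interprets the array.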
-
-// host buffer type
-
-static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
- return GGML_SYCL_NAME "_Host";
-
- GGML_UNUSED(buft);
-}
-
-static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- ggml_sycl_host_free(buffer->context);
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- void * ptr = ggml_sycl_host_malloc(size);
-
- if (ptr == nullptr) {
- // fallback to cpu buffer
- return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
- }
-
- // FIXME: this is a hack to avoid having to implement a new buffer type
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
- buffer->buft = buft;
- buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
-
- return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
- static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_sycl_host_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
- /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
- /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_sycl_buffer_type_host;
-}
-
-// buffer pool for sycl (legacy)
-struct ggml_sycl_pool_leg : public ggml_sycl_pool {
- static const int MAX_SYCL_BUFFERS = 256;
-
- int device;
- queue_ptr qptr;
- struct ggml_sycl_buffer {
- void * ptr = nullptr;
- size_t size = 0;
- };
-
- ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
- size_t pool_size = 0;
-
- explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
- qptr(qptr_),
- device(device_) {
- }
-
- ~ggml_sycl_pool_leg() {
- for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
- ggml_sycl_buffer & b = buffer_pool[i];
- if (b.ptr != nullptr) {
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
- pool_size -= b.size;
- }
- }
- GGML_ASSERT(pool_size == 0);
- }
-
- void * alloc(size_t size, size_t * actual_size) override {
-#ifdef DEBUG_SYCL_MALLOC
- int nnz = 0;
- size_t max_size = 0;
-#endif
- size_t best_diff = 1ull << 36;
- int ibest = -1;
- for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
- ggml_sycl_buffer& b = buffer_pool[i];
- if (b.ptr != nullptr) {
-#ifdef DEBUG_SYCL_MALLOC
- ++nnz;
- if (b.size > max_size) max_size = b.size;
-#endif
- if (b.size >= size) {
- size_t diff = b.size - size;
- if (diff < best_diff) {
- best_diff = diff;
- ibest = i;
- if (!best_diff) {
- void * ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- }
- }
- }
- }
- if (ibest >= 0) {
- ggml_sycl_buffer& b = buffer_pool[ibest];
- void * ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- void * ptr;
- size_t look_ahead_size = (size_t) (1.05 * size);
-
- SYCL_CHECK(
- CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
- look_ahead_size, *qptr)));
- if (!ptr) {
- fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, look_ahead_size);
- return nullptr;
- }
-
- *actual_size = look_ahead_size;
- pool_size += look_ahead_size;
-
- #ifdef DEBUG_SYCL_MALLOC
- fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
- (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
- #endif
- // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
- return ptr;
- }
-
- void free(void * ptr, size_t size) override {
- for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
- ggml_sycl_buffer& b = buffer_pool[i];
- if (b.ptr == nullptr) {
- b.ptr = ptr;
- b.size = size;
- return;
- }
- }
- fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n");
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
- pool_size -= size;
- }
-};
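-
-// Illustrative note on the legacy pool above (numbers are an example, not
-// measurements): when no pooled block fits, alloc() over-allocates the request
-// by 5%, so a 100 MB request reserves 105 MB; once that block is returned via
-// free() into buffer_pool, a later request of up to 105 MB (e.g. 103 MB) can
-// reuse it without touching the SYCL allocator again. Only when all
-// MAX_SYCL_BUFFERS slots are occupied is a freed block released immediately
-// with sycl::free().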
-
-std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
- // TBD: NO VMM support
- // if (ggml_sycl_info().devices[device].vmm) {
- // return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
- // }
- return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
-}
-
-// TBD pool with virtual memory management
-// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
-
-/// kernels
-
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_sycl_op_mul_mat_t)(
- ggml_backend_sycl_context & ctx,
- const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
- const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
- float *dst_dd_i, const int64_t row_low, const int64_t row_high,
- const int64_t src1_ncols, const int64_t src1_padded_row_size,
- const queue_ptr &stream);
-
-
-
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
- const sycl::nd_item<3> &item_ct1) {
- const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
- if (ix >= kx_padded) {
- return;
- }
-
- const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
-
- const int i_padded = iy*kx_padded + ix;
-
- block_q8_1 * y = (block_q8_1 *) vy;
-
- const int ib = i_padded / QK8_1; // block index
- const int iqs = i_padded % QK8_1; // quant index
- typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
- typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
- TC zeros;
- TQ qzeros;
-#pragma unroll
- for (int i = 0; i < QUANT_BLOCK_TILE; i++)
- {
- zeros[i] = 0.f;
- qzeros[i] = 0;
- }
- const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
- float sum = xi[0];
- float amax = sycl::fabs(xi[0]);
-#pragma unroll
- for (int i = 1; i < QUANT_BLOCK_TILE; i++)
- {
- sum += xi[i];
- amax = sycl::fmax(sycl::fabs(xi[i]), amax);
- }
- sum = warp_reduce_sum(sum, item_ct1);
- amax = warp_reduce_max(amax, item_ct1);
-
- const float d = amax / 127;
- TQ q = qzeros;
- if (amax != 0.0f)
- {
-#pragma unroll
- for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
- q[i] = sycl::round(xi[i] / d);
- }
- }
-
- *(TQ *)&y[ib].qs[iqs] = q;
-
- if (iqs > 0) {
- return;
- }
-
- reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
- reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
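-
-// Worked example for the q8_1 quantization above (illustrative values): for a
-// block whose largest magnitude is amax = 2.54, the scale is d = 2.54 / 127 =
-// 0.02, so an input of 1.27 is stored as round(1.27 / 0.02) = 64 and later
-// reconstructed as 64 * 0.02 = 1.28. The per-block sum written to ds.y() lets
-// kernels that pair q8_1 with asymmetric quant types fold in their bias terms.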
-
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
- const void * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2)) *
- 2;
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
- const int iybs = i00 - i00%qk; // dst block start index
- const int y_offset = qr == 1 ? 1 : qk/2;
-
- // dequantize
- dfloat2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
-
- dst_row[iybs + iqs + 0] = v.x();
- dst_row[iybs + iqs + y_offset] = v.y();
-}
-
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
- const src0_t * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2);
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
- dst_row[i00] = src0_row[i00];
-}
-
-static void mul_mat_p021_f16_f32(
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
- const sycl::nd_item<3> &item_ct1) {
-
- const sycl::half *x = (const sycl::half *)vx;
-
- const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
- item_ct1.get_local_id(0);
- const int channel_x = channel / (nchannels_y / nchannels_x);
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x;
- col_x0 += item_ct1.get_local_range(2)) {
- const int col_x = col_x0 + item_ct1.get_local_id(2);
-
- if (col_x >= ncols_x) {
- break;
- }
-
- // x is transposed and permuted
- const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
- const float xi =
- sycl::vec<sycl::half, 1>(x[ix])
- .convert<float, sycl::rounding_mode::automatic>()[0];
-
- const int row_y = col_x;
-
-
- // y is not transposed but permuted
- const int iy = channel*nrows_y + row_y;
-
- tmp += xi * y[iy];
- }
-
- // dst is not transposed and not permuted
- const int idst = channel*nrows_dst + row_dst;
-
- // sum up partial sums and write back result
-#pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
- tmp +=
- dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
- }
-
- if (item_ct1.get_local_id(2) == 0) {
- dst[idst] = tmp;
- }
-}
-
-static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
- const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
- const sycl::nd_item<3> &item_ct1) {
-
- const sycl::half *x = (const sycl::half *)vx;
-
- const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
- item_ct1.get_local_id(0);
- const int channel_x = channel / channel_x_divisor;
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- const int idst = channel*nrows_dst + row_dst;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x;
- col_x0 += item_ct1.get_local_range(2)) {
- const int col_x = col_x0 + item_ct1.get_local_id(2);
-
- if (col_x >= ncols_x) {
- break;
- }
-
- const int row_y = col_x;
-
- const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
- const int iy = channel*nrows_y + row_y;
-
- const float xi =
- sycl::vec<sycl::half, 1>(x[ix])
- .convert<float, sycl::rounding_mode::automatic>()[0];
-
- tmp += xi * y[iy];
- }
-
- // sum up partial sums and write back result
-#pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
- tmp +=
- dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
- }
-
- if (item_ct1.get_local_id(2) == 0) {
- dst[idst] = tmp;
- }
-}
-
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
-}
-
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = sycl::vec<float, 1>(*xi)
- .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = *xi;
-}
-
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
-}
-
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
- const int16_t *xi = (const int16_t *)cxi;
- int16_t *dsti = (int16_t *)cdsti;
-
- *dsti = *xi;
-}
-
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
- const int32_t *xi = (const int32_t *)cxi;
- int32_t *dsti = (int32_t *)cdsti;
-
- *dsti = *xi;
-}
-
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= ne) {
- return;
- }
-
- // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
- // then combine those indices with the corresponding byte offsets to get the total offsets
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
- cpy_1(cx + x_offset, cdst + dst_offset);
-}
-
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = xi[j];
- amax = sycl::fmax(amax, sycl::fabs((float)v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = xi[j]*id;
-
- dsti->qs[j] = sycl::round((float)x0);
- }
-}
-
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
- float amax = 0.0f;
- float vmax = 0.0f;
-
- for (int j = 0; j < QK4_0; ++j) {
- const float v = xi[j];
- if (amax < sycl::fabs((float)v)) {
- amax = sycl::fabs((float)v);
- vmax = v;
- }
- }
-
- const float d = vmax / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = xi[0 + j]*id;
- const float x1 = xi[QK4_0/2 + j]*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
-}
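-
-// Worked example for the q4_0 packing above (illustrative values): if the
-// value with the largest magnitude in the block is vmax = -1.0f, then
-// d = -1.0f / -8 = 0.125 and id = 8. An input of 0.5f maps to the nibble
-// (int8_t)(0.5f * 8 + 8.5f) = 12 and -1.0f maps to 0; element j of the first
-// half of the block lands in the low nibble of qs[j] and element QK4_0/2 + j
-// in the high nibble of the same byte.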
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
- float vmin = FLT_MAX;
- float vmax = -FLT_MAX;
-
- for (int j = 0; j < QK4_1; ++j) {
- const float v = xi[j];
-
- if (v < vmin) vmin = v;
- if (v > vmax) vmax = v;
- }
-
- const float d = (vmax - vmin) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->dm.x() = d;
- dsti->dm.y() = vmin;
-
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (xi[0 + j] - vmin)*id;
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
-}
-
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2)) *
- qk;
-
- if (i >= ne) {
- return;
- }
-
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
- cpy_blck(cx + x_offset, cdst + dst_offset);
-}
-
-static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
- const sycl::nd_item<3> &item_ct1) {
- const int row = item_ct1.get_group(1);
- const int col = item_ct1.get_local_id(2);
-
- float sum = 0.0f;
- for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
- sum += x[row * ncols + i];
- }
-
- sum = warp_reduce_sum(sum, item_ct1);
-
- if (col == 0) {
- dst[row] = sum;
- }
-}
-
-
-template<typename T>
-static inline void ggml_sycl_swap(T & a, T & b) {
- T tmp = a;
- a = b;
- b = tmp;
-}
-
-template <ggml_sort_order order>
-__dpct_inline__ static void
-k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
- const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
- // bitonic sort
- int col = item_ct1.get_local_id(2);
- int row = item_ct1.get_group(1);
-
- if (col >= ncols_pad) {
- return;
- }
-
- const float * x_row = x + row * ncols;
- auto dst_row = (int *)dpct_local;
-
- // initialize indices
- dst_row[col] = col;
-
- item_ct1.barrier(sycl::access::fence_space::local_space);
-
- for (int k = 2; k <= ncols_pad; k *= 2) {
- for (int j = k / 2; j > 0; j /= 2) {
- int ixj = col ^ j;
- if (ixj > col) {
- if ((col & k) == 0) {
- if (dst_row[col] >= ncols ||
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
- ) {
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
- }
- } else {
- if (dst_row[ixj] >= ncols ||
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
- ) {
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
- }
- }
- }
- /*
- DPCT1118:1: SYCL group functions and algorithms must be encountered
- in converged control flow. You may need to adjust the code.
- */
- item_ct1.barrier(sycl::access::fence_space::local_space);
- }
- }
-
- // copy the result to dst without the padding
- if (col < ncols) {
- dst[row * ncols + col] = dst_row[col];
- }
-}
-
-
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
- const sycl::nd_item<3> &item_ct1) {
- const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (col >= ncols) {
- return;
- }
-
- const int i = row*ncols + col;
- //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
- //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
- dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
-
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
- const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= k) {
- return;
- }
-
- dst[i] = scale * x[i];
-}
-
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
- const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= k) {
- return;
- }
-
- dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
-
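-// NCHW 2D pooling: one work-item per output element; the pooling window is clamped
-// to the input bounds and reduced with either an average or a max, per ggml_op_pool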
-template <typename Ti, typename To>
-static void pool2d_nchw_kernel(
- const int ih, const int iw, const int oh, const int ow,
- const int kh, const int kw, const int sh, const int sw,
- const int ph, const int pw, const int parallel_elements,
- const Ti* src, To* dst, const enum ggml_op_pool op,
- const sycl::nd_item<3> &item_ct1) {
- int idx = item_ct1.get_local_id(2) +
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
- if (idx >= parallel_elements) {
- return;
- }
-
- const int I_HW = ih * iw;
- const int O_HW = oh * ow;
- const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / ow;
- const int cur_ow = idx % O_HW % ow;
- const Ti* i_ptr = src + nc * I_HW;
- To* o_ptr = dst + nc * O_HW;
- const int start_h = cur_oh * sh - ph;
- const int bh = sycl::max(0, start_h);
- const int eh = sycl::min(ih, start_h + kh);
- const int start_w = cur_ow * sw - pw;
- const int bw = sycl::max(0, start_w);
- const int ew = sycl::min(iw, start_w + kw);
-
- To res = 0;
-
- switch (op) {
- case GGML_OP_POOL_AVG: res = 0; break;
- case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
- }
-
- for (int i = bh; i < eh; i += 1) {
- for (int j = bw; j < ew; j += 1) {
-#if DPCT_COMPATIBILITY_TEMP >= 350
- /*
- DPCT1098:106: The '*' expression is used instead of the __ldg
- call. These two expressions do not provide the exact same
- functionality. Check the generated code for potential precision
- and/or performance issues.
- */
- Ti cur = *(i_ptr + i * iw + j);
-#else
- Ti cur = i_ptr[i * iw + j];
-#endif
- switch (op) {
- case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
- case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
- }
- }
- }
- o_ptr[cur_oh * ow + cur_ow] = res;
-}
-
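-// dequantizing get_rows launcher: every work-item handles two source elements
-// (hence the division by 2*SYCL_GET_ROWS_BLOCK_SIZE and the ne00 % 2 assert)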
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const void *src0_dd,
- const int32_t *src1_dd, float *dst_dd,
- queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- GGML_ASSERT(ne00 % 2 == 0);
-
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows<qk, qr, dq>(
- src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
-
- (void) dst;
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const src0_t *src0_dd, const int32_t *src1_dd,
- float *dst_dd, queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
- }
-
- (void) dst;
-}
-
-
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
- const int ky, const int kx_padded,
- queue_ptr stream) {
- const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
- const sycl::range<3> num_blocks(1, ky, block_num_x);
- int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
- static_assert(QK8_1 % WARP_SIZE == 0);
- const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(num_blocks * block_size, block_size),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
- quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
- });
- }
-}
-
-static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
- float *dst, const int ncols_x,
- const int nrows_x,
- const int nchannels_x,
- const int nchannels_y,
- queue_ptr stream) {
-
- const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
- const sycl::range<3> block_dims(1, 1, WARP_SIZE);
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
- mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
- nchannels_y, item_ct1);
- });
- }
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_sycl(
- const void *vx, const float *y, float *dst, const int ncols_x,
- const int nrows_x, const int row_stride_x, const int nchannels_x,
- const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
-
- const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
- const sycl::range<3> block_dims(1, 1, WARP_SIZE);
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
- mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
- row_stride_x, channel_stride_x,
- nchannels_y / nchannels_x, item_ct1);
- });
- }
-}
-
-static void
-ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
- const int ne01, const int ne02, const int nb00,
- const int nb01, const int nb02, const int nb03,
- const int ne10, const int ne11, const int ne12,
- const int nb10, const int nb11, const int nb12,
- const int nb13, queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
- nb01, nb02, nb03, ne10, ne11, ne12,
- nb10, nb11, nb12, nb13, item_ct1);
- });
- }
-}
-
-static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-}
-
-static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-}
-
-static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK8_0 == 0);
- const int num_blocks = ne / QK8_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
-}
-
-static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_0 == 0);
- const int num_blocks = ne / QK4_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
-}
-
-static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_1 == 0);
- const int num_blocks = ne / QK4_1;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
-}
-
-static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-}
-
-static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-}
-
-static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-}
-
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
- const int k, queue_ptr stream) {
- const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- scale_f32(x, dst, scale, k, item_ct1);
- });
-}
-
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
- const float max, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- clamp_f32(x, dst, min, max, k, item_ct1);
- });
-}
-
-static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
- const int nrows, queue_ptr stream) {
- const sycl::range<3> block_dims(1, 1, WARP_SIZE);
- const sycl::range<3> block_nums(1, nrows, 1);
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
- k_sum_rows_f32(x, dst, ncols, item_ct1);
- });
-}
-
-static int next_power_of_2(int x) {
- int n = 1;
- while (n < x) {
- n *= 2;
- }
- return n;
-}
-
-static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
- const int nrows, ggml_sort_order order,
- queue_ptr stream) {
- // bitonic sort requires ncols to be power of 2
- const int ncols_pad = next_power_of_2(ncols);
-
- const sycl::range<3> block_dims(1, 1, ncols_pad);
- const sycl::range<3> block_nums(1, nrows, 1);
- const size_t shared_mem = ncols_pad * sizeof(int);
-
- if (order == GGML_SORT_ORDER_ASC) {
- stream->submit([&](sycl::handler &cgh) {
- sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
- sycl::range<1>(shared_mem), cgh);
-
- cgh.parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
- x, dst, ncols, ncols_pad, item_ct1,
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
- .get());
- });
- });
- } else if (order == GGML_SORT_ORDER_DESC) {
- stream->submit([&](sycl::handler &cgh) {
- sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
- sycl::range<1>(shared_mem), cgh);
-
- cgh.parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
- x, dst, ncols, ncols_pad, item_ct1,
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
- .get());
- });
- });
- } else {
- GGML_ABORT("fatal error");
- }
-}
-
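-// per-row argmax: each work-item scans a strided subset of the columns, the block
-// then reduces the partial maxima and their indices in local memory, and work-item 0
-// writes the winning column index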
-static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
- const int nrows, queue_ptr stream) {
- const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
- const sycl::range<3> block_nums(1, nrows, 1);
- const size_t shared_mem = SYCL_ARGMAX_BLOCK_SIZE * sizeof(float);
-
- stream->submit([&](sycl::handler &cgh) {
- sycl::local_accessor<float, 1> shared_data(
- sycl::range<1>(shared_mem/sizeof(float)), cgh);
- sycl::local_accessor<int, 1> shared_indices(
- sycl::range<1>(shared_mem/sizeof(float)), cgh);
-
- cgh.parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- const int tid = item_ct1.get_local_id(2);
- const int row = item_ct1.get_global_id(1);
-
- float max_val = -INFINITY;
- int max_idx = -1;
-
- for (int col = tid; col < ncols; col += SYCL_ARGMAX_BLOCK_SIZE) {
- float val = x[row * ncols + col];
- if (val > max_val) {
- max_val = val;
- max_idx = col;
- }
- }
-
- shared_data[tid] = max_val;
- shared_indices[tid] = max_idx;
- item_ct1.barrier(sycl::access::fence_space::local_space);
-
- for (int stride = SYCL_ARGMAX_BLOCK_SIZE/2; stride > 0; stride >>= 1) {
- if (tid < stride) {
- float val1 = shared_data[tid];
- float val2 = shared_data[tid + stride];
- if (val2 > val1) {
- shared_data[tid] = val2;
- shared_indices[tid] = shared_indices[tid + stride];
- }
- }
- item_ct1.barrier(sycl::access::fence_space::local_space);
- }
-
-
- if (tid == 0) {
- dst[row] = shared_indices[0];
- }
- });
- });
-}
-static void diag_mask_inf_f32_sycl(const float *x, float *dst,
- const int ncols_x, const int nrows_x,
- const int rows_per_channel, const int n_past,
- queue_ptr stream) {
- const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
- const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
- const sycl::range<3> block_nums(1, block_num_x, nrows_x);
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- diag_mask_inf_f32(x, dst, ncols_x,
- rows_per_channel, n_past,
- item_ct1);
- });
-}
-
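-// copy rows [i1_low, i1_high) of one (i2, i3) plane of src into a contiguous buffer,
-// choosing host-to-device or device-to-device and falling back to a strided 2D copy
-// (or per-row copies) when the source rows are not contiguous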
-static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
- const struct ggml_tensor *src,
- int64_t i3, int64_t i2,
- int64_t i1_low, int64_t i1_high,
- queue_ptr stream) try {
-
- dpct::memcpy_direction kind;
- char * src_ptr;
- if (src->backend == GGML_BACKEND_TYPE_CPU) {
- kind = dpct::host_to_device;
- src_ptr = (char *) src->data;
- // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
- } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
- GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
- kind = dpct::device_to_device;
- ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
- int id;
- SYCL_CHECK(CHECK_TRY_ERROR(
- id = get_current_device_id()));
- // GGML_SYCL_DEBUG("current device index %d\n", id);
- src_ptr = (char *) extra->data_device[id];
- } else {
- // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
- GGML_ABORT("fatal error");
- }
- char * dst_ptr = (char *) dst;
-
- GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
- GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
- const enum ggml_type type = src->type;
- const int64_t ts = ggml_type_size(type);
- const int64_t bs = ggml_blck_size(type);
- int64_t i1_diff = i1_high - i1_low;
-
- const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*ne0/bs) {
- // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
- // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
- return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
- kind, *stream));
-
- } else if (nb0 == ts) {
- return CHECK_TRY_ERROR(
- dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
- ts * ne0 / bs, i1_diff, kind, *stream));
- } else {
- for (int64_t i1 = 0; i1 < i1_diff; i1++) {
- const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
- // pretend the row is a matrix with cols=1
- dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
- rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
- /*
- DPCT1001:85: The statement could not be removed.
- */
- /*
- DPCT1000:86: Error handling if-stmt was detected but could not be
- rewritten.
- */
- if (r != 0) return r;
- }
- return 0;
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d, const queue_ptr &stream) {
-
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
- const int32_t * src1_i32 = (const int32_t *) src1_d;
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
- src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_F32:
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_0:
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_1:
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_0:
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_1:
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q8_0:
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- default:
- // TODO: k-quants
- fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
- GGML_ABORT("fatal error");
- break;
- }
-}
-
-
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d,
- const queue_ptr &main_stream) {
-
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
- (void) src1;
- (void) src1_d;
-}
-
-
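-// single-device GEMM for one row range of src0: when GGML_SYCL_F16 is enabled and the
-// operands qualify, both inputs are converted to fp16 and the half-precision GEMM is
-// used, otherwise everything is converted to fp32; the DNNL path replaces the oneMKL
-// call when GGML_SYCL_DNNL is set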
-inline void ggml_sycl_op_mul_mat_sycl(
- ggml_backend_sycl_context & ctx,
- const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
- const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
- float *dst_dd_i, const int64_t row_low, const int64_t row_high,
- const int64_t src1_ncols, const int64_t src1_padded_row_size,
- const queue_ptr &stream) try {
-
- GGML_ASSERT(src0_dd_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_dd_i != nullptr);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne10 = src1->ne[0];
-
- const int64_t ne0 = dst->ne[0];
-
- const int64_t row_diff = row_high - row_low;
-
- int id;
- SYCL_CHECK(
- CHECK_TRY_ERROR(id = get_current_device_id()));
-
- // the main device has a larger memory buffer to hold the results from all GPUs
- // ldc == nrows of the matrix that the GEMM writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
-
-#ifdef GGML_SYCL_F16
- bool use_fp16 = true; // TODO(Yu) SYCL capability check
-#else
- bool use_fp16 = false;
-#endif
- if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
- use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
- dst->op_params[0] == GGML_PREC_DEFAULT) {
-
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
- ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
- if (src0->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
- GGML_ASSERT(to_fp16_sycl != nullptr);
- size_t ne = row_diff*ne00;
- src0_as_f16.alloc(ne);
- to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
- }
- const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
- ? (const sycl::half *)src0_dd_i
- : src0_as_f16.get();
-
- ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
- GGML_ASSERT(to_fp16_sycl != nullptr);
- size_t ne = src1_ncols*ne10;
- src1_as_f16.alloc(ne);
- to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
- }
- const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
- ? (const sycl::half *)src1->data + src1_padded_row_size
- : src1_as_f16.get();
- ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
-
- const sycl::half alpha_f16 = 1.0f;
- const sycl::half beta_f16 = 0.0f;
-#if !GGML_SYCL_DNNL
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
- src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
- dst_f16.get(), dpct::library_data_t::real_half, ldc,
- dpct::library_data_t::real_half)));
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
-#endif
- }
- else {
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
- ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
- ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
- if (src0->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
- GGML_ASSERT(to_fp32_sycl != nullptr);
- src0_ddq_as_f32.alloc(row_diff*ne00);
- to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
- }
- if (src1->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
- GGML_ASSERT(to_fp32_sycl != nullptr);
- src1_ddq_as_f32.alloc(src1_ncols*ne10);
- to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
- }
- const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
- const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
-
- const float alpha = 1.0f;
- const float beta = 0.0f;
-#if !GGML_SYCL_DNNL
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
- src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
- dst_dd_i, ldc)));
-#else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
- src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
-#endif
- }
- (void) dst;
- (void) src1_ddq_i;
- (void) src1_padded_row_size;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd, const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
- const int k0 = opts[1];
- const int k1 = opts[2];
- const int s0 = opts[3];
- const int s1 = opts[4];
- const int p0 = opts[5];
- const int p1 = opts[6];
-
- const int64_t IH = src0->ne[1];
- const int64_t IW = src0->ne[0];
-
- const int64_t N = dst->ne[3];
- const int64_t OC = dst->ne[2];
- const int64_t OH = dst->ne[1];
- const int64_t OW = dst->ne[0];
-
- const int parallel_elements = N * OC * OH * OW;
- const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
- sycl::range<3> block_nums(1, 1, num_blocks);
- main_stream->parallel_for(
- sycl::nd_range<3>(block_nums *
- sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE)),
- sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
- parallel_elements, src0_dd, dst_dd, op,
- item_ct1);
- });
-
- (void) src1;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int64_t ne = ggml_nelements(src0);
-
- sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
-
- sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
-
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
- argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
-
- argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int nrows0 = ggml_nrows(src0);
-
- const int n_past = ((int32_t *) dst->op_params)[0];
-
- diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- float scale;
- memcpy(&scale, dst->op_params, sizeof(float));
-
- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
- /*
- DPCT1010:87: SYCL uses exceptions to report errors and does not use the
- error codes. The call was replaced with 0. You need to rewrite this code.
- */
- SYCL_CHECK(0);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- float min;
- float max;
- memcpy(&min, dst->op_params, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
- clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
- /*
- DPCT1010:88: SYCL uses exceptions to report errors and does not use the
- error codes. The call was replaced with 0. You need to rewrite this code.
- */
- SYCL_CHECK(0);
-
- (void) src1;
- (void) dst;
- (void) src1_dd;
-}
-
-static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
- static bool peer_access_enabled = false;
-
- const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
-
- if (peer_access_enabled == enable_peer_access) {
- return;
- }
-
-#ifdef NDEBUG
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- SYCL_CHECK(ggml_sycl_set_device(i));
- }
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- SYCL_CHECK(ggml_sycl_set_device(i));
-
- for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
- if (i == id_other) {
- continue;
- }
- if (i != main_device && id_other != main_device) {
- continue;
- }
-
- // int can_access_peer;
- // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
- // if (can_access_peer) {
- // if (enable_peer_access) {
- // SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
- // } else {
- // SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
- // }
- // }
- }
- }
-#endif // NDEBUG
-
- peer_access_enabled = enable_peer_access;
-}
-
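-// multi-device mat-mul driver: splits the rows of src0 across devices according to
-// tensor_split, optionally quantizes src1 to q8_1, streams src1 in column blocks
-// (MUL_MAT_SRC1_COL_STRIDE at a time when the weights are split across devices), and
-// copies the partial results back to the main device, synchronizing with barrier events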
-static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- ggml_sycl_op_mul_mat_t op,
- const bool convert_src1_to_q8_1) try {
-
- GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-
- GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
- const int64_t nrows1 = ggml_nrows(src1);
-
- GGML_ASSERT(ne03 == ne13);
-
- const int64_t ne0 = dst->ne[0];
- const int64_t ne1 = dst->ne[1];
-
- const int nb2 = dst->nb[2];
- const int nb3 = dst->nb[3];
-
- GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
-
- GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
-
- const int64_t i02_divisor = ne12 / ne02;
-
- const size_t src0_ts = ggml_type_size(src0->type);
- const size_t src0_bs = ggml_blck_size(src0->type);
- const size_t q8_1_ts = sizeof(block_q8_1);
- const size_t q8_1_bs = QK8_1;
-
- ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
- ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-
- const bool src0_is_contiguous = ggml_is_contiguous(src0);
- const bool src1_is_contiguous = ggml_is_contiguous(src1);
-
- int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
-
- const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
- GGML_ASSERT(!(split && ne02 > 1));
- GGML_ASSERT(!(split && ne03 > 1));
- GGML_ASSERT(!(split && ne02 < ne12));
-
- std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
- if (split) {
- // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
- // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
- tensor_split = buft_ctx->tensor_split;
- }
-
- struct dev_data {
- ggml_sycl_pool_alloc<char> src0_dd_alloc;
- ggml_sycl_pool_alloc<float> src1_ddf_alloc;
- ggml_sycl_pool_alloc<char> src1_ddq_alloc;
- ggml_sycl_pool_alloc<float> dst_dd_alloc;
-
- char *src0_dd = nullptr;
- float *src1_ddf = nullptr; // float
- char *src1_ddq = nullptr; // q8_1
- float *dst_dd = nullptr;
-
- int64_t row_low;
- int64_t row_high;
- };
-
- dev_data dev[GGML_SYCL_MAX_DEVICES];
-
- int used_devices = 0;
- queue_ptr main_stream = ctx.stream();
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- // by default, use all rows
- dev[i].row_low = 0;
- dev[i].row_high = ne01;
-
- // for multi GPU, get the row boundaries from tensor split
- // and round to mul_mat_q tile sizes
- if (split) {
- const int64_t rounding = get_row_rounding(src0->type, tensor_split);
-
- if (i != 0) {
- dev[i].row_low = ne01*tensor_split[i];
- if (dev[i].row_low < ne01) {
- dev[i].row_low -= dev[i].row_low % rounding;
- }
- }
-
- if (i != ggml_sycl_info().device_count - 1) {
- dev[i].row_high = ne01*tensor_split[i + 1];
- if (dev[i].row_high < ne01) {
- dev[i].row_high -= dev[i].row_high % rounding;
- }
- }
- }
- }
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
- continue;
- }
-
- used_devices++;
-
- const bool src1_on_device = i == ctx.device;
- const bool dst_on_device = i == ctx.device;
-
- ggml_sycl_set_device(i);
- queue_ptr stream = ctx.stream(i, 0);
-
- if (src0_is_contiguous) {
- dev[i].src0_dd = (char *) src0->data;
- } else {
- dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
- }
-
- if (src1_on_device && src1_is_contiguous) {
- dev[i].src1_ddf = (float *) src1->data;
- } else {
- dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
- }
-
- if (convert_src1_to_q8_1) {
- dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
-
- if (src1_on_device && src1_is_contiguous) {
- quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
- /*
- DPCT1010:90: SYCL uses exceptions to report errors and does not
- use the error codes. The call was replaced with 0. You need to
- rewrite this code.
- */
- SYCL_CHECK(0);
- }
- }
-
- if (dst_on_device) {
- dev[i].dst_dd = (float *) dst->data;
- } else {
- const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
- dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
- }
- }
-
- // if multiple devices are used they need to wait for the main device
- // here an event is recorded that signals that the main device has finished calculating the input data
- if (split && used_devices > 1) {
- ggml_sycl_set_device(ctx.device);
- /*
- DPCT1024:91: The original code returned the error code that was further
- consumed by the program logic. This original code was replaced with 0.
- You may need to rewrite the program logic consuming the error code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- *src0_extra->events[ctx.device][0] =
- ctx.stream()->ext_oneapi_submit_barrier()));
- }
-
- const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
- for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
- const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
- const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
- continue;
- }
-
- const bool src1_on_device = i == ctx.device;
- const bool dst_on_device = i == ctx.device;
- const int64_t row_diff = dev[i].row_high - dev[i].row_low;
-
- ggml_sycl_set_device(i);
- queue_ptr stream = ctx.stream(i, is);
-
- // wait for main GPU data if necessary
- if (split && (i != ctx.device || is != 0)) {
- /*
- DPCT1009:163: SYCL uses exceptions to report errors and does not
- use the error codes. The original code was commented out and a
- warning string was inserted. You need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
- {*src0_extra->events[ctx.device][0]})));
- }
-
- for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
- const int64_t i03 = i0 / ne12;
- const int64_t i02 = i0 % ne12;
-
- const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
-
- // for split tensors the data begins at i0 == i0_offset_low
- char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
- float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
- char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset;
- float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
-
- // the main device memory buffer can be on VRAM scratch, with space for all partial results
- // in that case an offset on dst_ddf_i is needed
- if (i == ctx.device) {
- dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
- }
-
- // copy src0, src1 to device if necessary
- if (src1_is_contiguous) {
- if (i != ctx.device) {
- if (convert_src1_to_q8_1) {
- char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
- src1_ddq_i, src1_ddq_i_source,
- src1_ncols * src1_padded_col_size * q8_1_ts /
- q8_1_bs).wait()));
- } else {
-
- float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
- src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-
- SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
- src1_ddf_i, src1_ddf_i_source,
- src1_ncols * ne10 * sizeof(float))));
- }
- }
- } else if (src1_on_device && !src1_is_contiguous) {
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
- src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
- } else {
- GGML_ABORT("fatal error");
- }
-
- if (convert_src1_to_q8_1 && !src1_is_contiguous) {
- quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
- /*
- DPCT1010:92: SYCL uses exceptions to report errors and does
- not use the error codes. The call was replaced with 0. You
- need to rewrite this code.
- */
- SYCL_CHECK(0);
- }
-
- if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
- }
- if (src1->type == GGML_TYPE_F16) {
- src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
- }
- // do the computation
- SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
- dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
- /*
- DPCT1010:93: SYCL uses exceptions to report errors and does not
- use the error codes. The call was replaced with 0. You need to
- rewrite this code.
- */
- SYCL_CHECK(0);
-
- // copy dst to host or other device if necessary
- if (!dst_on_device) {
- void * dst_off_device = dst->data;
- if (split) {
- // src0 = weight matrix is saved as a transposed matrix for better memory layout.
- // dst is NOT transposed.
- // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
- // Instead they need to be copied to the correct slice in ne0 = dst row index.
- // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
- float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
- GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
- dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
-
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
- dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
- row_diff * sizeof(float), row_diff * sizeof(float),
- src1_ncols, dpct::device_to_device, *stream)));
- } else {
- float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
- GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
- dhf_dst_i += src1_col_0*ne0;
- SYCL_CHECK(CHECK_TRY_ERROR(
- stream->memcpy(dhf_dst_i, dst_dd_i,
- src1_ncols * ne0 * sizeof(float)).wait()));
- }
- }
-
- // add event for the main device to wait on until other device is done
- if (split && (i != ctx.device || is != 0)) {
- /*
- DPCT1024:94: The original code returned the error code that
- was further consumed by the program logic. This original
- code was replaced with 0. You may need to rewrite the
- program logic consuming the error code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- *src0_extra->events[i][is] =
- stream->ext_oneapi_submit_barrier()));
- }
- }
- }
- }
-
- // main device waits for all other devices to be finished
- if (split && ggml_sycl_info().device_count > 1) {
- int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
- is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
-
- ggml_sycl_set_device(ctx.device);
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- if (dev[i].row_low == dev[i].row_high) {
- continue;
- }
- for (int64_t is = 0; is < is_max; ++is) {
- SYCL_CHECK(CHECK_TRY_ERROR(
- ctx.stream()->ext_oneapi_submit_barrier(
- {*src0_extra->events[i][is]})));
- }
- }
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst) try {
- GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
- GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t ne12 = src1->ne[2];
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst) try {
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(!ggml_is_permuted(src0));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t nb01 = src0->nb[1];
- const int64_t nb02 = src0->nb[2];
-
- const int64_t ne12 = src1->ne[2];
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- const int64_t row_stride_x = nb01 / sizeof(sycl::half);
- const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
-
- ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
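-// build the per-matrix pointer arrays consumed by dpct::gemm_batch, applying the
-// broadcast factors r2/r3 so that several (i12, i13) batches can share one src0 plane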
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
- const sycl::half *src1_as_f16, char *dst,
- const void **ptrs_src, void **ptrs_dst,
- int64_t ne12, int64_t ne13, int64_t ne23,
- size_t nb02, size_t nb03, size_t nb12,
- size_t nb13, size_t nbd2, size_t nbd3,
- int64_t r2, int64_t r3,
- const sycl::nd_item<3> &item_ct1) {
- int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2);
- int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
- item_ct1.get_local_id(1);
-
- if (i13 >= ne13 || i12 >= ne12) {
- return;
- }
-
- int64_t i03 = i13 / r3;
- int64_t i02 = i12 / r2;
-
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
-}
-
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
- const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst) try {
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const int64_t ne_dst = ggml_nelements(dst);
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- // convert src1 to fp16
- ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
- const int64_t ne_src1 = ggml_nelements(src1);
- src1_f16_alloc.alloc(ne_src1);
- GGML_ASSERT(to_fp16_sycl != nullptr);
- to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
- }
- sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
- : src1_f16_alloc.get();
-
- char * dst_t;
-
- dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
- dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
-
- // dst strides
- size_t nbd2 = dst->nb[2];
- size_t nbd3 = dst->nb[3];
-
- const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
-
- const void * alpha = &alpha_f32;
- const void * beta = &beta_f32;
-
- dst_t = (char *) dst_ddf;
-
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const char *)src0_as_f16, dpct::library_data_t::real_half,
- nb01 / nb00, nb02 / nb00,
- (const char *)src1_f16, dpct::library_data_t::real_half,
- nb11 / nb10, nb12 / nb10, beta,
- (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
- ne12 * ne13, cu_compute_type)));
- } else {
- const int ne23 = ne12*ne13;
-
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
- ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
-
- sycl::range<3> block_dims(1, ne12, ne13);
- /*
- DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
- the limit. To get the device limit, query
- info::device::max_work_group_size. Adjust the work-group size if needed.
- */
- {
- dpct::has_capability_or_fail(main_stream->get_device(),
- {sycl::aspect::fp16});
-
- main_stream->submit([&](sycl::handler &cgh) {
- const void **ptrs_src_get = ptrs_src.get();
- void **ptrs_dst_get = ptrs_dst.get();
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_compute_batched_ptrs(
- src0_as_f16, src1_f16,
- dst_t, ptrs_src_get,
- ptrs_dst_get, ne12, ne13, ne23,
- nb02, nb03, nb12_scaled, nb13_scaled,
- nbd2, nbd3, r2, r3, item_ct1);
- });
- });
- }
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const void **)(ptrs_src.get() + 0 * ne23),
- dpct::library_data_t::real_half, nb01 / nb00,
- (const void **)(ptrs_src.get() + 1 * ne23),
- dpct::library_data_t::real_half, nb11 / nb10, beta,
- (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
- cu_compute_type)));
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
-    // TODO: accuracy issues in MMQ
-    GGML_UNUSED(type);
-    return false;
-}
-
-bool ggml_sycl_supports_dmmv(enum ggml_type type) {
- switch (type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_F16:
- return true;
- default:
- return false;
- }
-}
-
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
- int64_t min_compute_capability = INT_MAX;
-
- if (split) {
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
- auto & tensor_split = buft_ctx->tensor_split;
- for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
- // skip devices that are not going to do any work:
- if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
- continue;
- }
-
- if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
- min_compute_capability = ggml_sycl_info().devices[id].cc;
- }
- }
- } else {
- min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
- }
-
- // check data types and tensor shapes for custom matrix multiplication kernels:
- bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
-
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
- bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
- // mmvq and mmq need the __dp4a instruction which is available for gen12+
- // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
- use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-#ifdef SYCL_USE_XMX
- use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // SYCL_USE_XMX
-
- // mmvq path is faster in the CUDA backend.
- if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
- use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-
- if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
- // KQ single-batch
- ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
- // KQV single-batch
- ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
- // KQ + KQV multi-batch
- ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
- } else if (use_dequantize_mul_mat_vec) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
- } else if (use_mul_mat_vec_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
- } else if (use_mul_mat_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
- } else {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
- }
-}
-
-
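-// maps a row in the compacted src1/dst buffers back to its original position:
-// i1 is the index within the ids row, i2 is the ids row (token) index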
-struct mmid_row_mapping {
- int32_t i1;
- int32_t i2;
-};
-
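-// gathers the src1 rows whose routing id matches i02 into a contiguous buffer and
-// records where each copied row came from so the result can be scattered back later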
-__dpct_inline__ static void k_copy_src1_to_contiguous(
- const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
- int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
- const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
- int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
- const sycl::nd_item<3> &item_ct1, int &src1_row) {
- int32_t iid1 = item_ct1.get_group(2);
- int32_t id = item_ct1.get_group(1);
-
- const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
- if (row_id_i != i02) {
- return;
- }
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = iid1;
-
- if (item_ct1.get_local_id(2) == 0) {
- src1_row =
- dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
- cur_src1_row, 1);
- row_mapping[src1_row] = {id, iid1};
- }
- /*
- DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
- performance if there is no access to global memory.
- */
- item_ct1.barrier();
-
- const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
- float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-#pragma unroll
- for (int i = item_ct1.get_local_id(2); i < ne10;
- i += item_ct1.get_local_range(2)) {
- src1_row_contiguous[i] = src1_row_original[i];
- }
-}
-
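-// scatters rows of the contiguous dst buffer back to their original positions using row_mapping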
-__dpct_inline__ static void k_copy_dst_from_contiguous(
- char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
- const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
- size_t nb2, const sycl::nd_item<3> &item_ct1) {
- int32_t i = item_ct1.get_group(2);
-
- const int32_t i1 = row_mapping[i].i1;
- const int32_t i2 = row_mapping[i].i2;
-
- const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
- float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-#pragma unroll
- for (int j = item_ct1.get_local_id(2); j < ne0;
- j += item_ct1.get_local_range(2)) {
- dst_row_original[j] = dst_row_contiguous[j];
- }
-}
-
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst) try {
- GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
-
- const ggml_tensor *ids = dst->src[2];
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const queue_ptr stream = ctx.stream();
-
- const int64_t n_as = ne02;
- const int64_t n_ids = ids->ne[0];
-
- std::vector<char> ids_host(ggml_nbytes(ids));
- const char * ids_dev = (const char *) ids->data;
-
- SYCL_CHECK(CHECK_TRY_ERROR(
- stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
- SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-
- ggml_tensor src0_row = *src0;
- ggml_tensor src1_row = *src1;
- ggml_tensor dst_row = *dst;
-
- char *src0_original = (char *)src0->data;
- char *src1_original = (char *)src1->data;
- char *dst_original = (char *)dst->data;
-
- src0_row.ne[2] = 1;
- src0_row.ne[3] = 1;
- src0_row.nb[3] = nb02;
-
- src1_row.ne[1] = 1;
- src1_row.ne[2] = 1;
- src1_row.ne[3] = 1;
- src1_row.nb[2] = nb11;
- src1_row.nb[3] = nb11;
-
- dst_row.ne[1] = 1;
- dst_row.ne[2] = 1;
- dst_row.ne[3] = 1;
- dst_row.nb[2] = nb1;
- dst_row.nb[3] = nb1;
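-    // ne12 == 1: a single src1 batch, so each selected row can be multiplied directly;
-    // otherwise the rows routed to each expert are first gathered into a contiguous buffer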
- if (ne12 == 1) {
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
- const int64_t i11 = id % ne11;
- const int64_t i12 = iid1;
-
- const int64_t i1 = id;
- const int64_t i2 = i12;
-
- src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
- dst_row.data = dst_original + i1*nb1 + i2*nb2;
-
- ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
- }
- }
- } else {
- ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
- ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
- src1_row.data = src1_contiguous.get();
- dst_row.data = dst_contiguous.get();
-
- for (int64_t i02 = 0; i02 < n_as; i02++) {
- int64_t num_src1_rows = 0;
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
- if (row_id_i != i02) {
- continue;
- }
-
- num_src1_rows++;
- }
- }
-
- if (num_src1_rows == 0) {
- continue;
- }
-
-
- ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
- ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
- SYCL_CHECK(CHECK_TRY_ERROR(
- stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
-
- {
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
- sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
- stream->submit([&](sycl::handler &cgh) {
- sycl::local_accessor<int, 0> src1_row_acc(cgh);
-
- char *__restrict src1_contiguous_get =
- src1_contiguous.get();
- int *__restrict dev_cur_src1_row_get =
- dev_cur_src1_row.get();
- mmid_row_mapping *__restrict dev_row_mapping_get =
- dev_row_mapping.get();
- size_t ids_nb_ct6 = ids->nb[1];
- size_t ids_nb_ct7 = ids->nb[0];
-
- cgh.parallel_for(
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_copy_src1_to_contiguous(
- src1_original, src1_contiguous_get,
- dev_cur_src1_row_get,
- dev_row_mapping_get, ids_dev, i02,
- ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
- item_ct1, src1_row_acc);
- });
- });
- }
-
- src0_row.data = src0_original + i02*nb02;
-
- GGML_ASSERT(nb11 == sizeof(float)*ne10);
- GGML_ASSERT(nb1 == sizeof(float)*ne0);
- src1_row.ne[1] = num_src1_rows;
-
- src1_row.nb[1] = nb11;
- src1_row.nb[2] = num_src1_rows*nb11;
- src1_row.nb[3] = num_src1_rows*nb11;
-
- dst_row.ne[1] = num_src1_rows;
- dst_row.nb[1] = nb1;
- dst_row.nb[2] = num_src1_rows*nb1;
- dst_row.nb[3] = num_src1_rows*nb1;
-
- ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
- {
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
- sycl::range<3> grid_dims(1, 1, num_src1_rows);
- stream->submit([&](sycl::handler &cgh) {
- const char *__restrict dst_contiguous_get =
- dst_contiguous.get();
- const mmid_row_mapping *__restrict dev_row_mapping_get =
- dev_row_mapping.get();
-
- cgh.parallel_for(
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_copy_dst_from_contiguous(dst_original,
- dst_contiguous_get,
- dev_row_mapping_get,
- ne0, nb1, nb2, item_ct1);
- });
- });
- }
- }
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
-}
-
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
-}
-
-static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst) try {
- const int64_t ne = ggml_nelements(src0);
- GGML_ASSERT(ne == ggml_nelements(src1));
-
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
- GGML_TENSOR_BINARY_OP_LOCALS01;
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- char * src0_ddc = (char *) src0->data;
- char * src1_ddc = (char *) src1->data;
-
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else {
- fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
- ggml_type_name(src0->type), ggml_type_name(src1->type));
- GGML_ABORT("fatal error");
- }
-
- (void) dst;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- // TODO: why do we pass dst as src1 here?
- ggml_sycl_cpy(ctx, src0, dst, nullptr);
- (void) src1;
-}
-
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
-}
-
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
-}
-
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
-}
-
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
-}
-
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
-}
-
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
-}
-
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
-}
-
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
-}
-
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
-}
-
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- (void) src0;
- (void) src1;
- (void) dst;
-}
-
-void ggml_sycl_set_main_device(const int main_device) try {
- if (dpct::get_current_device_id() == main_device) return;
- check_allow_gpu_index(main_device);
- dpct::select_device(main_device);
-
- if (g_ggml_sycl_debug) {
- dpct::device_info prop;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(main_device))));
- fprintf(stderr, "Using device %d (%s) as main device\n",
- main_device, prop.get_name());
- }
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
- if (!g_sycl_loaded) return false;
-
- ggml_sycl_func_t func;
-
- switch (tensor->op) {
- case GGML_OP_ARGMAX:
- func = ggml_sycl_argmax;
- break;
- case GGML_OP_CONV_TRANSPOSE_1D:
- func = ggml_sycl_op_conv_transpose_1d;
- break;
- case GGML_OP_REPEAT:
- func = ggml_sycl_repeat;
- break;
- case GGML_OP_GET_ROWS:
- func = ggml_sycl_get_rows;
- break;
- case GGML_OP_DUP:
- func = ggml_sycl_dup;
- break;
- case GGML_OP_ADD:
- case GGML_OP_ADD1: // TODO: more efficient implementation
- func = ggml_sycl_add;
- break;
- case GGML_OP_SUB:
- func = ggml_sycl_sub;
- break;
- case GGML_OP_ACC:
- func = ggml_sycl_acc;
- break;
- case GGML_OP_MUL:
- func = ggml_sycl_mul;
- break;
- case GGML_OP_LOG:
- func = ggml_sycl_log;
- break;
- case GGML_OP_DIV:
- func = ggml_sycl_div;
- break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(tensor)) {
- case GGML_UNARY_OP_NEG:
- func = ggml_sycl_neg;
- break;
- case GGML_UNARY_OP_STEP:
- func = ggml_sycl_step;
- break;
- case GGML_UNARY_OP_GELU:
- func = ggml_sycl_gelu;
- break;
- case GGML_UNARY_OP_SILU:
- func = ggml_sycl_silu;
- break;
- case GGML_UNARY_OP_GELU_QUICK:
- func = ggml_sycl_gelu_quick;
- break;
- case GGML_UNARY_OP_TANH:
- func = ggml_sycl_tanh;
- break;
- case GGML_UNARY_OP_RELU:
- func = ggml_sycl_relu;
- break;
- case GGML_UNARY_OP_SIGMOID:
- func = ggml_sycl_sigmoid;
- break;
- case GGML_UNARY_OP_HARDSIGMOID:
- func = ggml_sycl_hardsigmoid;
- break;
- case GGML_UNARY_OP_HARDSWISH:
- func = ggml_sycl_hardswish;
- break;
- case GGML_UNARY_OP_EXP:
- func = ggml_sycl_exp;
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_NORM:
- func = ggml_sycl_norm;
- break;
- case GGML_OP_GROUP_NORM:
- func = ggml_sycl_group_norm;
- break;
- case GGML_OP_CONCAT:
- func = ggml_sycl_op_concat;
- break;
- case GGML_OP_UPSCALE:
- func = ggml_sycl_upscale;
- break;
- case GGML_OP_PAD:
- func = ggml_sycl_pad;
- break;
- case GGML_OP_LEAKY_RELU:
- func = ggml_sycl_leaky_relu;
- break;
- case GGML_OP_RMS_NORM:
- func = ggml_sycl_rms_norm;
- break;
- case GGML_OP_MUL_MAT:
- if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
- return false;
- }
- func = ggml_sycl_mul_mat;
- break;
- case GGML_OP_MUL_MAT_ID:
- if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
- return false;
- }
- func = ggml_sycl_mul_mat_id;
- break;
- case GGML_OP_OUT_PROD:
- func = ggml_sycl_op_out_prod;
- break;
- case GGML_OP_SCALE:
- func = ggml_sycl_scale;
- break;
- case GGML_OP_SQR:
- func = ggml_sycl_sqr;
- break;
- case GGML_OP_SQRT:
- func = ggml_sycl_sqrt;
- break;
- case GGML_OP_SIN:
- func = ggml_sycl_sin;
- break;
- case GGML_OP_COS:
- func = ggml_sycl_cos;
- break;
- case GGML_OP_CLAMP:
- func = ggml_sycl_clamp;
- break;
- case GGML_OP_CPY:
- func = ggml_sycl_cpy;
- break;
- case GGML_OP_CONT:
- func = ggml_sycl_dup;
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- func = ggml_sycl_nop;
- break;
- case GGML_OP_DIAG_MASK_INF:
- func = ggml_sycl_diag_mask_inf;
- break;
- case GGML_OP_SOFT_MAX:
- func = ggml_sycl_soft_max;
- break;
- case GGML_OP_ROPE:
- func = ggml_sycl_rope;
- break;
- case GGML_OP_IM2COL:
- func = ggml_sycl_im2col;
- break;
- case GGML_OP_POOL_2D:
- func = ggml_sycl_pool2d;
- break;
- case GGML_OP_SUM:
- func = ggml_sycl_sum;
- break;
- case GGML_OP_SUM_ROWS:
- func = ggml_sycl_sum_rows;
- break;
- case GGML_OP_ARGSORT:
- func = ggml_sycl_argsort;
- break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- func = ggml_sycl_op_timestep_embedding;
- break;
- case GGML_OP_RWKV_WKV6:
- func = ggml_sycl_op_rwkv_wkv6;
- break;
- default:
- return false;
- }
-
- if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
- ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
- }
-
- func(ctx, tensor->src[0], tensor->src[1], tensor);
- return true;
-}
-
-GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
- size_t description_size) try {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
- dpct::device_info prop;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(device))));
- snprintf(description, description_size, "%s", prop.get_name());
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-void ggml_backend_sycl_get_device_memory(int device, size_t *free,
- size_t *total) try {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
- ggml_sycl_set_device(device);
-
- /*
- DPCT1009:218: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string was
- inserted. You need to rewrite this code.
- */
- /*
- DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
- device information which may not be supported by all compilers or runtimes.
- You may need to adjust the code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend
-
-static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
-
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
- return sycl_ctx->name.c_str();
-}
-
-static void ggml_backend_sycl_free(ggml_backend_t backend) {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
- delete sycl_ctx;
- delete backend;
-}
-
-static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
- ggml_tensor *tensor,
- const void *data, size_t offset,
- size_t size) try {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
- const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
- SYCL_CHECK(CHECK_TRY_ERROR(
- (stream)->memcpy((char *)tensor->data + offset, data, size)));
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
- const ggml_tensor *tensor,
- void *data, size_t offset,
- size_t size) try {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
- GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
- const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
- SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
- data, (const char *)tensor->data + offset, size).wait()));
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
- const ggml_tensor *src,
- ggml_tensor *dst) try {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
- /*
- DPCT1009:215: SYCL uses exceptions to report errors and does not use the
- error codes. The original code was commented out and a warning string
- was inserted. You need to rewrite this code.
- */
- const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
- SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
- dst->data, src->data, ggml_nbytes(dst)).wait()));
- return true;
- }
-
- return false;
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
- SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
-
- GGML_UNUSED(backend);
-}
-catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- ggml_sycl_set_main_device(sycl_ctx->device);
-
-
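-    // run the graph nodes in order; pure layout ops carry no work and are skipped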
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor * node = cgraph->nodes[i];
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
- continue;
- }
-#ifndef NDEBUG
- assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- if (node->src[j] != nullptr) {
- assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
- }
- }
-#endif
- bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
- if (!ok) {
- fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
- }
- GGML_ASSERT(ok);
- }
-
- return GGML_STATUS_SUCCESS;
-}
-
-static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
-try
-{
- ggml_backend_sycl_context *sycl_ctx =
- (ggml_backend_sycl_context *)backend->context;
- sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
-
- const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
- // Record the current state of the queue
- SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
-}
-catch (sycl::exception const &exc)
-{
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
- ggml_backend_sycl_context* sycl_ctx = static_cast<ggml_backend_sycl_context*>(backend->context);
- sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
-
- if (ggml_backend_is_sycl(backend)) {
- SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
-    } else {
-        GGML_ABORT("fatal error");
-    }
-} catch (sycl::exception const& exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static ggml_backend_i ggml_backend_sycl_interface = {
- /* .get_name = */ ggml_backend_sycl_get_name,
- /* .free = */ ggml_backend_sycl_free,
- /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async,
- /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async,
-                                             // TODO: update for the new interface
- /* .synchronize = */ ggml_backend_sycl_synchronize,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_sycl_graph_compute,
- /* .event_record = */ ggml_backend_sycl_event_record,
- /* .event_wait = */ ggml_backend_sycl_event_wait,
-};
-
-static ggml_guid_t ggml_backend_sycl_guid() {
- static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
- return &guid;
-}
-
-bool ggml_backend_is_sycl(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
-}
-
-int ggml_backend_sycl_get_device_count() {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
- return ggml_sycl_info().device_count;
-}
-
-
-// backend device
-
-struct ggml_backend_sycl_device_context {
- int device;
- std::string name;
- std::string description;
-};
-
-static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
- return ctx->name.c_str();
-}
-
-static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
- return ctx->description.c_str();
-}
-
-static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
- ggml_sycl_set_device(ctx->device);
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
-}
-
-static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
- props->name = ggml_backend_sycl_device_get_name(dev);
- props->description = ggml_backend_sycl_device_get_description(dev);
- props->type = ggml_backend_sycl_device_get_type(dev);
- ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
- bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
-#ifdef GGML_SYCL_NO_PEER_COPY
- bool events = false;
-#else
- bool events = true;
-#endif
-
- props->caps = {
- /* .async = */ true,
- /* .host_buffer = */ host_buffer,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ events,
- };
-}
-
-static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
- GGML_UNUSED(params);
- ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
- return ggml_backend_sycl_init(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
- return ggml_backend_sycl_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
- GGML_UNUSED(dev);
- return ggml_backend_sycl_host_buffer_type();
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
- GGML_UNUSED(dev);
- GGML_UNUSED(ptr);
- GGML_UNUSED(size);
- GGML_UNUSED(max_tensor_size);
- return nullptr;
-}
-
-static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_CONV_TRANSPOSE_1D:
- {
- ggml_type src0_type = op->src[0]->type;
- ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- return false;
- } break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_NEG:
- case GGML_UNARY_OP_STEP:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_SIGMOID:
- case GGML_UNARY_OP_HARDSIGMOID:
- case GGML_UNARY_OP_HARDSWISH:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_TANH:
- case GGML_UNARY_OP_EXP:
- return ggml_is_contiguous(op->src[0]);
- default:
- return false;
- }
- break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- {
- struct ggml_tensor * a;
- struct ggml_tensor * b;
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
- // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
- return false;
- }
- } else {
- a = op->src[2];
- b = op->src[1];
- }
- if (a->ne[3] != b->ne[3]) {
- return false;
- }
- ggml_type a_type = a->type;
- if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
- a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
- a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
- a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
- ) {
- if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
- return false;
- }
- }
- ggml_type src0_type = op->src[0]->type;
- if (src0_type == GGML_TYPE_BF16) {
- return false;
- }
- return true;
- } break;
- case GGML_OP_OUT_PROD:
- return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
- case GGML_OP_GET_ROWS:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_F32:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- return true;
- default:
- return false;
- }
- } break;
- case GGML_OP_CPY:
- {
- ggml_type src0_type = op->src[0]->type;
- ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
- return true;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- return false;
- } break;
- case GGML_OP_CONCAT:
- {
- ggml_type src0_type = op->src[0]->type;
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
- case GGML_OP_DUP:
- case GGML_OP_ARGMAX:
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_REPEAT:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_NORM:
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- case GGML_OP_LOG:
- case GGML_OP_SUB:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_RMS_NORM:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- return true;
- case GGML_OP_CONT:
- return op->src[0]->type != GGML_TYPE_BF16;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- return true;
- case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_IM2COL:
- // TODO: add support for the new F32 operations
- return op->src[0]->type == GGML_TYPE_F16;
- case GGML_OP_POOL_2D:
- case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_ARGSORT:
- case GGML_OP_ACC:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_UPSCALE:
- case GGML_OP_PAD:
- case GGML_OP_LEAKY_RELU:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_RWKV_WKV6:
- return true;
- default:
- return false;
- }
-
- GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
- return false;
- }
- ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
- ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
- return buft_ctx->device == sycl_ctx->device;
-}
-
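-// effective batch size of an op (rows/tokens processed), used by the offload heuristic below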
-static int64_t get_op_batch_size(const ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_GET_ROWS:
-            return op->ne[1]; // this increases prefill speed in tests
- case GGML_OP_MUL_MAT:
- return op->ne[1];
- case GGML_OP_MUL_MAT_ID:
- case GGML_OP_ROPE:
- return op->ne[2];
- default:
- return ggml_nrows(op);
- }
-}
-
-static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
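-    // only offload ops whose batch size is at least min_batch_size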
- const int min_batch_size = 32;
- return get_op_batch_size(op) >= min_batch_size;
- GGML_UNUSED(dev);
-}
-
-static ggml_backend_event_t
-ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
-
-#ifdef GGML_SYCL_NO_PEER_COPY
- return nullptr;
-#else
- sycl::event *event_ptr = new sycl::event();
-
- return new ggml_backend_event{
- /* .device = */ dev,
- /* .context = */ event_ptr,
- };
-#endif
-}
-
-static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
- GGML_UNUSED(dev);
- if (event == nullptr) {
- return;
- }
-
- if (event->context != nullptr) {
- sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
- delete sycl_event;
- event->context = nullptr;
- }
-
- delete event;
-} catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-
-static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
- GGML_UNUSED(dev);
-
- sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
- SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
-} catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
-}
-
-static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
- /* .get_name = */ ggml_backend_sycl_device_get_name,
- /* .get_description = */ ggml_backend_sycl_device_get_description,
- /* .get_memory = */ ggml_backend_sycl_device_get_memory,
- /* .get_type = */ ggml_backend_sycl_device_get_type,
- /* .get_props = */ ggml_backend_sycl_device_get_props,
- /* .init_backend = */ ggml_backend_sycl_device_init,
- /* .get_buffer_type = */ ggml_backend_sycl_device_get_buffer_type,
- /* .get_host_buffer_type = */ ggml_backend_sycl_device_get_host_buffer_type,
- /* .buffer_from_host_ptr = */ ggml_backend_sycl_device_buffer_from_host_ptr,
- /* .supports_op = */ ggml_backend_sycl_device_supports_op,
- /* .supports_buft = */ ggml_backend_sycl_device_supports_buft,
- /* .offload_op = */ ggml_backend_sycl_device_offload_op,
- /* .event_new = */ ggml_backend_sycl_device_event_new,
- /* .event_free = */ ggml_backend_sycl_device_event_free,
- /* .event_synchronize = */ ggml_backend_sycl_device_event_synchronize,
-};
-
-// backend reg
-
-struct ggml_backend_sycl_reg_context {
- std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
- GGML_UNUSED(reg);
- return GGML_SYCL_NAME;
-}
-
-static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
- ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
- return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
- GGML_ASSERT(index < ctx->devices.size());
- return ctx->devices[index];
-}
-
-static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
- GGML_UNUSED(reg);
-
- // TODO: update to the current function signature
- //if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
- // return (void *)ggml_backend_sycl_split_buffer_type;
- //}
-
- // SYCL doesn't support registering host memory, left here for reference
- // "ggml_backend_register_host_buffer"
- // "ggml_backend_unregister_host_buffer"
- return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
- /* .get_name = */ ggml_backend_sycl_reg_get_name,
- /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count,
-    /* .get_device          = */ ggml_backend_sycl_reg_get_device,
- /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address,
-};
-
-
-// backend registry
-
-ggml_backend_reg_t ggml_backend_sycl_reg() {
- static ggml_backend_reg reg;
- static bool initialized = false;
-
- {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- if (!initialized) {
- ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
-
- for (int i = 0; i < ggml_sycl_info().device_count; i++) {
- ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
- dev_ctx->device = i;
- dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
-
- ggml_sycl_set_device(i);
-
- dpct::device_info prop;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(i))));
-
- dev_ctx->description = prop.get_name();
-
- ggml_backend_dev_t dev = new ggml_backend_device {
- /* .interface = */ ggml_backend_sycl_device_interface,
-                    /* .reg       = */ &reg,
- /* .context = */ dev_ctx
- };
- ctx->devices.push_back(dev);
- }
-
- reg = ggml_backend_reg {
- /* .interface = */ ggml_backend_sycl_reg_interface,
- /* .context = */ ctx
- };
- }
-
- initialized = true;
- }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_sycl_init(int device) {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
- ggml_check_sycl();
-
- check_allow_gpu_index(device);
-
- ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
- if (ctx == nullptr) {
- fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
- return nullptr;
- };
-
- ggml_backend_t sycl_backend = new ggml_backend {
- /* .guid = */ ggml_backend_sycl_guid(),
- /* .interface = */ ggml_backend_sycl_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
- /* .context = */ ctx
- };
-
- return sycl_backend;
-}
-
+++ /dev/null
-#include "ggml-vulkan.h"
-#include <vulkan/vulkan_core.h>
-#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
-#include <chrono>
-#endif
-
-#include <vulkan/vulkan.hpp>
-
-#include <algorithm>
-#include <cmath>
-#include <iomanip>
-#include <iostream>
-#include <tuple>
-#include <vector>
-#include <sstream>
-#include <utility>
-#include <memory>
-#include <limits>
-#include <map>
-#include <unordered_map>
-#include <mutex>
-#include <future>
-#include <thread>
-
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-vulkan-shaders.hpp"
-
-#define VK_API_VERSION VK_API_VERSION_1_2
-
-#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
-
-#define VK_VENDOR_ID_AMD 0x1002
-#define VK_VENDOR_ID_APPLE 0x106b
-#define VK_VENDOR_ID_INTEL 0x8086
-#define VK_VENDOR_ID_NVIDIA 0x10de
-
-#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32
-
-#define GGML_VK_MAX_NODES 8192
-
-#define MAX_VK_BUFFERS 256
-
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 1
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
-#define VK_CHECK(err, msg) \
- do { \
- vk::Result err_ = (err); \
- if (err_ != vk::Result::eSuccess) { \
- fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \
- #err, to_string(err_).c_str(), __FILE__, __LINE__); \
- exit(1); \
- } \
- } while (0)
-
-#ifdef GGML_VULKAN_DEBUG
-#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
-#else
-#define VK_LOG_DEBUG(msg) ((void) 0)
-#endif // GGML_VULKAN_DEBUG
-
-struct ggml_backend_vk_context;
-
-struct vk_queue {
- uint32_t queue_family_index;
- vk::Queue queue;
- vk::CommandPool pool;
- uint32_t cmd_buffer_idx;
- std::vector<vk::CommandBuffer> cmd_buffers;
-
- vk::PipelineStageFlags stage_flags;
-
- bool transfer_only;
-};
-
-struct vk_pipeline_struct {
- std::string name;
- vk::ShaderModule shader_module;
- vk::DescriptorSetLayout dsl;
- std::vector<vk::DescriptorPool> descriptor_pools;
- std::vector<vk::DescriptorSet> descriptor_sets;
- uint32_t descriptor_set_idx;
- vk::PipelineLayout layout;
- vk::Pipeline pipeline;
- uint32_t push_constant_size;
- uint32_t parameter_count;
- std::array<uint32_t, 3> wg_denoms;
- uint32_t align;
-};
-
-typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
-typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
-
-static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
-
-struct vk_matmul_pipeline_struct {
- vk_pipeline l, m, s;
- vk_pipeline a_l, a_m, a_s;
-};
-
-typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
-
-struct vk_device_struct;
-typedef std::shared_ptr<vk_device_struct> vk_device;
-typedef std::weak_ptr<vk_device_struct> vk_device_ref;
-
-struct vk_buffer_struct;
-typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
-typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
-
-struct ggml_backend_vk_buffer_type_context {
- std::string name;
- vk_device device;
-};
-
-static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
-static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
-static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
-static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
-static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
- /* .get_name = */ ggml_backend_vk_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment,
- /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
- /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size,
- /* .is_host = */ NULL,
-};
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-class vk_memory_logger;
-#endif
-#ifdef GGML_VULKAN_PERF
-class vk_perf_logger;
-#endif
-static void ggml_vk_destroy_buffer(vk_buffer& buf);
-
-struct vk_device_struct {
- std::mutex mutex;
-
- vk::PhysicalDevice physical_device;
- vk::PhysicalDeviceProperties properties;
- std::string name;
- uint64_t max_memory_allocation_size;
- bool fp16;
- vk::Device device;
- uint32_t vendor_id;
- vk_queue compute_queue;
- vk_queue transfer_queue;
- bool single_queue;
- uint32_t subgroup_size;
- bool uma;
-
- size_t idx;
-
- vk_matmul_pipeline pipeline_matmul_f32;
- vk_matmul_pipeline pipeline_matmul_f32_f16;
- vk_matmul_pipeline pipeline_matmul_f16;
- vk_matmul_pipeline pipeline_matmul_f16_f32;
- vk_pipeline pipeline_matmul_split_k_reduce;
-
- vk_matmul_pipeline pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
-
- vk_matmul_pipeline pipeline_matmul_id_f32;
- vk_matmul_pipeline pipeline_matmul_id_f16;
- vk_matmul_pipeline pipeline_matmul_id_f16_f32;
-
- vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
-
- vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
- vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
-
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
- vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
- vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
- vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_acc_f32;
- vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
- vk_pipeline pipeline_mul_f32;
- vk_pipeline pipeline_div_f32;
- vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
- vk_pipeline pipeline_upscale_f32;
- vk_pipeline pipeline_scale_f32;
- vk_pipeline pipeline_sqr_f32;
- vk_pipeline pipeline_sin_f32;
- vk_pipeline pipeline_cos_f32;
- vk_pipeline pipeline_clamp_f32;
- vk_pipeline pipeline_pad_f32;
- vk_pipeline pipeline_repeat_f32;
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
- vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
- vk_pipeline pipeline_norm_f32;
- vk_pipeline pipeline_group_norm_f32;
- vk_pipeline pipeline_rms_norm_f32;
- vk_pipeline pipeline_gelu_f32;
- vk_pipeline pipeline_gelu_quick_f32;
- vk_pipeline pipeline_silu_f32;
- vk_pipeline pipeline_relu_f32;
- vk_pipeline pipeline_leaky_relu_f32;
- vk_pipeline pipeline_tanh_f32;
- vk_pipeline pipeline_diag_mask_inf_f32;
- vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
- vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
- vk_pipeline pipeline_argsort_f32;
- vk_pipeline pipeline_sum_rows_f32;
- vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
- vk_pipeline pipeline_timestep_embedding_f32;
- vk_pipeline pipeline_pool2d_f32;
-
- std::unordered_map<std::string, vk_pipeline_ref> pipelines;
- std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
-
- std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
-
- vk::Fence fence;
- vk_buffer sync_staging;
-
- ggml_backend_buffer_type buffer_type;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
- std::unique_ptr<vk_memory_logger> memory_logger;
-#endif
-#ifdef GGML_VULKAN_PERF
- std::unique_ptr<vk_perf_logger> perf_logger;
-#endif
-
- ~vk_device_struct() {
- VK_LOG_DEBUG("destroy device " << name);
-
- device.destroyFence(fence);
-
- ggml_vk_destroy_buffer(sync_staging);
-
- device.destroyCommandPool(compute_queue.pool);
- if (!single_queue) {
- device.destroyCommandPool(transfer_queue.pool);
- }
-
- for (auto& pipeline : pipelines) {
- if (pipeline.second.expired()) {
- continue;
- }
-
- vk_pipeline pl = pipeline.second.lock();
- ggml_vk_destroy_pipeline(device, pl);
- }
- pipelines.clear();
-
- device.destroy();
- }
-};
-
-struct vk_buffer_struct {
- vk::Buffer buffer = VK_NULL_HANDLE;
- vk::DeviceMemory device_memory = VK_NULL_HANDLE;
- vk::MemoryPropertyFlags memory_property_flags;
- void * ptr;
- size_t size = 0;
-
- vk_device device;
-
- ~vk_buffer_struct() {
- if (size == 0) {
- return;
- }
- VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
-
- device->device.freeMemory(device_memory);
- device->device.destroyBuffer(buffer);
- }
-};
-
-struct vk_subbuffer {
- vk_buffer buffer;
- uint64_t offset;
- uint64_t size;
-
- operator vk::DescriptorBufferInfo() const {
- return { buffer->buffer, offset, size };
- }
-};
-
-struct vk_semaphore {
- vk::Semaphore s;
- uint64_t value;
-};
-
-struct vk_submission {
- vk::CommandBuffer buffer;
- std::vector<vk_semaphore> wait_semaphores;
- std::vector<vk_semaphore> signal_semaphores;
-};
-
-typedef std::vector<vk_submission> vk_sequence;
-
-struct vk_mat_mat_push_constants {
- uint32_t M; uint32_t N; uint32_t K;
- uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
- uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
- uint32_t k_split;
- uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
-};
-struct vk_mat_vec_push_constants {
- uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
- uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
- uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
-};
-
-struct vk_mat_mat_id_push_constants {
- uint32_t M; uint32_t N; uint32_t K;
- uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
- uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
- uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
-};
-struct vk_mat_vec_id_push_constants {
- uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
- uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
- uint32_t nei0; uint32_t ne11;
-};
-
-struct vk_op_push_constants {
- uint32_t KX;
- uint32_t KY;
- float param1;
- float param2;
-};
-
-struct vk_op_unary_push_constants {
- uint32_t ne;
- uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
- uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
- uint32_t d_offset;
- float param1; float param2;
-};
-
-struct vk_op_binary_push_constants {
- uint32_t ne;
- uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
- uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
- uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
- uint32_t d_offset;
- float param1; float param2; int32_t param3;
-};
-
-struct vk_op_diag_mask_push_constants {
- uint32_t ncols;
- uint32_t rows_per_channel;
- int32_t n_past;
-};
-
-struct vk_op_rope_push_constants {
- uint32_t ncols;
- uint32_t n_dims;
- float freq_scale;
- uint32_t p_delta_rows;
- float freq_base;
- float ext_factor;
- float attn_factor;
- float corr_dims[2];
- float theta_scale;
- uint32_t has_ff;
-};
-
-struct vk_op_soft_max_push_constants {
- uint32_t KX;
- uint32_t KY;
- float scale;
- float max_bias;
- float m0;
- float m1;
- uint32_t n_head_log2;
-};
-
-struct vk_op_argsort_push_constants {
- uint32_t ncols;
- uint32_t ncols_pad;
- int32_t order;
-};
-
-struct vk_op_im2col_push_constants {
- uint32_t batch_offset; uint32_t offset_delta;
- uint32_t IC;
- uint32_t IW; uint32_t IH;
- uint32_t OW; uint32_t OH;
- uint32_t KW; uint32_t KH;
- uint32_t pelements;
- uint32_t CHW;
- int32_t s0; int32_t s1;
- int32_t p0; int32_t p1;
- int32_t d0; int32_t d1;
-};
-
-struct vk_op_timestep_embedding_push_constants {
- uint32_t nb1;
- uint32_t dim;
- uint32_t max_period;
-};
-
-struct vk_op_pool2d_push_constants {
- uint32_t IW; uint32_t IH;
- uint32_t OW; uint32_t OH;
- uint32_t OC;
- uint32_t pelements;
- uint32_t op;
- int32_t k0; int32_t k1;
- int32_t s0; int32_t s1;
- int32_t p0; int32_t p1;
-};
-
-// Allow pre-recording command buffers
-struct vk_staging_memcpy {
- vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
-
- void * dst;
- const void * src;
- size_t n;
-};
-
-struct vk_op_upscale_push_constants {
- uint32_t ne; uint32_t d_offset;
- uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
- uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
- float sf0; float sf1; float sf2; float sf3;
-};
-
-struct vk_context_struct {
- vk_submission * s;
- std::vector<vk_sequence> seqs;
-
- int exit_tensor_idx;
-
- std::vector<vk_staging_memcpy> in_memcpys;
- std::vector<vk_staging_memcpy> out_memcpys;
-
- vk_queue * q;
-};
-typedef std::shared_ptr<vk_context_struct> vk_context;
-typedef std::weak_ptr<vk_context_struct> vk_context_ref;
-
-struct ggml_vk_garbage_collector {
- std::vector<vk_semaphore> tl_semaphores;
- std::vector<vk_semaphore> semaphores;
- std::vector<vk::Event> events;
- std::vector<vk_buffer> temp_buffers;
- std::vector<vk_context> contexts;
-};
-
-#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
-#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
-
-static std::string format_size(size_t size) {
- const size_t kib = 1024;
- const size_t mib = kib * 1024;
- const size_t gib = mib * 1024;
-
- std::ostringstream oss;
- oss << std::fixed << std::setprecision(2);
-
- if (size >= gib) {
- oss << static_cast<double>(size) / gib << " GiB";
- } else if (size >= mib) {
- oss << static_cast<double>(size) / mib << " MiB";
- } else if (size >= kib) {
- oss << static_cast<double>(size) / kib << " KiB";
- } else {
- oss << size << " B";
- }
-
- return oss.str();
-}
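-
-// Example outputs (illustrative, not part of the original file):
-//     format_size(512)                        -> "512 B"
-//     format_size(1536 * 1024)                -> "1.50 MiB"  (1536 KiB = 1.5 MiB)
-//     format_size(3ull * 1024 * 1024 * 1024)  -> "3.00 GiB"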
-
-static std::mutex log_mutex;
-
-class vk_memory_logger {
-public:
- vk_memory_logger(): total_device(0), total_host(0) {}
- void log_allocation(vk_buffer_ref buf_ref, size_t size);
- void log_deallocation(vk_buffer_ref buf_ref);
-
-private:
- std::map<vk::Buffer, size_t> allocations; // Track allocations
- size_t total_device;
- size_t total_host;
-};
-#else
-#define VK_LOG_MEMORY(msg) ((void) 0)
-#endif // GGML_VULKAN_MEMORY_DEBUG || GGML_VULKAN_DEBUG
-
-#if defined(GGML_VULKAN_PERF)
-
-class vk_perf_logger {
-public:
- void print_timings() {
- std::cerr << "----------------\nVulkan Timings:" << std::endl;
- for (const auto& t : timings) {
- uint64_t total = 0;
- for (const auto& time : t.second) {
- total += time;
- }
- std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl;
- }
-
- timings.clear();
- }
-
- void log_timing(const ggml_tensor * node, uint64_t time) {
- if (node->op == GGML_OP_UNARY) {
- timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
- return;
- }
- if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
- const uint64_t m = node->src[0]->ne[1];
- const uint64_t n = node->src[1]->ne[1];
- const uint64_t k = node->src[1]->ne[0];
- std::string name = ggml_op_name(node->op);
- if (n == 1) {
- name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
- } else {
- name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
- }
- timings[name].push_back(time);
- return;
- }
- timings[ggml_op_name(node->op)].push_back(time);
- }
-private:
- std::map<std::string, std::vector<uint64_t>> timings;
-};
-#endif // GGML_VULKAN_PERF
-
-struct ggml_backend_vk_context {
- std::string name;
-
- vk_device device;
-
- size_t semaphore_idx, event_idx;
- ggml_vk_garbage_collector gc;
- size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
- vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
- vk::Fence fence;
-
- vk_buffer buffer_pool[MAX_VK_BUFFERS];
-
- vk_context_ref compute_ctx;
- vk_context_ref transfer_ctx;
-
- std::vector<vk_context_ref> tensor_ctxs;
-};
-
-static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
-
-static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
- if (tensor->view_src) {
- return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
- }
- return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
-}
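-
-// Clarifying note (illustrative, not part of the original file): tensor->data is interpreted here
-// as vk_ptr_base plus an offset into the device buffer, so a tensor whose data pointer equals
-// (char *) vk_ptr_base + 4096 yields vk_tensor_offset(...) == 4096. Views resolve through
-// view_src, so they report the offset of the tensor they alias.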
-
-struct ggml_backend_vk_buffer_context {
- vk_device_ref device;
- vk_buffer dev_buffer;
- std::string name;
-
- ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
- device(device),
- dev_buffer(std::move(dev_buffer)),
- name(name) {
- }
-
- ~ggml_backend_vk_buffer_context() {
- ggml_vk_destroy_buffer(dev_buffer);
- }
-};
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
- std::lock_guard<std::mutex> guard(log_mutex);
- vk_buffer buf = buf_ref.lock();
- const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
- const std::string type = device ? "device" : "host";
- allocations[buf->buffer] = size;
- total_device += device ? size : 0;
- total_host += device ? 0 : size;
- VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
-}
-
-void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
- if (buf_ref.expired() || buf_ref.lock()->size == 0) {
- return;
- }
-
- std::lock_guard<std::mutex> guard(log_mutex);
- vk_buffer buf = buf_ref.lock();
- const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
- const std::string type = device ? "device" : "host";
- auto it = allocations.find(buf->buffer);
- if (it != allocations.end()) {
- // Only adjust the running totals for allocations that were actually tracked,
- // otherwise it->second would dereference the end iterator.
- total_device -= device ? it->second : 0;
- total_host -= device ? 0 : it->second;
- VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
- allocations.erase(it);
- } else {
- VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
- }
-}
-#endif // GGML_VULKAN_MEMORY_DEBUG
-
-struct vk_instance_t {
- vk::Instance instance;
-
- std::vector<size_t> device_indices;
- vk_device devices[GGML_VK_MAX_DEVICES];
-};
-
-static bool vk_instance_initialized = false;
-static vk_instance_t vk_instance;
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
-static size_t vk_skip_checks;
-static size_t vk_output_tensor;
-
-static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
-#endif
-
-typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-
-static void ggml_backend_vk_free(ggml_backend_t backend);
-
-// variables to track number of compiles in progress
-static uint32_t compile_count = 0;
-static std::mutex compile_count_mutex;
-static std::condition_variable compile_count_cond;
-
-static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
- VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
- GGML_ASSERT(parameter_count > 0);
- GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
-
- pipeline = std::make_shared<vk_pipeline_struct>();
- pipeline->name = name;
- pipeline->parameter_count = parameter_count;
- pipeline->push_constant_size = push_constant_size;
- pipeline->wg_denoms = wg_denoms;
- pipeline->align = align;
-
- vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
- pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
-
- std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
- std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
- for (uint32_t i = 0; i < parameter_count; i++) {
- dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
- dsl_binding_flags.push_back({});
- }
-
- vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
-
- vk::PushConstantRange pcr(
- vk::ShaderStageFlagBits::eCompute,
- 0,
- pipeline->push_constant_size
- );
-
- vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
- {},
- dsl_binding);
- descriptor_set_layout_create_info.setPNext(&dslbfci);
- pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
-
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
- vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
- pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-
- pipeline->descriptor_set_idx = 0;
-
- vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
- pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
-
- std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
-
- for (size_t i = 0; i < specialization_constants.size(); i++) {
- specialization_entries[i].constantID = i;
- specialization_entries[i].offset = i * sizeof(uint32_t);
- specialization_entries[i].size = sizeof(uint32_t);
- }
-
- vk::SpecializationInfo specialization_info(
- specialization_entries.size(),
- specialization_entries.data(),
- specialization_constants.size() * sizeof(uint32_t),
- specialization_constants.data()
- );
-
- vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
- vk::PipelineShaderStageCreateFlags(),
- vk::ShaderStageFlagBits::eCompute,
- pipeline->shader_module,
- entrypoint.c_str(),
- &specialization_info);
- vk::ComputePipelineCreateInfo compute_pipeline_create_info(
- vk::PipelineCreateFlags(),
- pipeline_shader_create_info,
- pipeline->layout);
- pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
-
- {
- std::lock_guard<std::mutex> guard(device->mutex);
- device->pipelines.insert({ pipeline->name, pipeline });
- }
-
- {
- std::lock_guard<std::mutex> guard(compile_count_mutex);
- assert(compile_count > 0);
- compile_count--;
-
- // "Progress bar" for shader compiles
- static uint32_t total_compile_count = 0;
- if ((total_compile_count++ % 10) == 0) {
- std::cerr << ".";
- }
- }
- compile_count_cond.notify_all();
-}
-
-static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
- VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
- for (auto& pool : pipeline->descriptor_pools) {
- device.destroyDescriptorPool(pool);
- }
- pipeline->descriptor_pools.clear();
- pipeline->descriptor_sets.clear();
- pipeline->descriptor_set_idx = 0;
-
- device.destroyDescriptorSetLayout(pipeline->dsl);
-
- device.destroyPipelineLayout(pipeline->layout);
-
- device.destroyShaderModule(pipeline->shader_module);
-
- device.destroyPipeline(pipeline->pipeline);
-}
-
-static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
- VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
- device->pipeline_descriptor_set_requirements[pipeline->name] += n;
-}
-
-static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
- std::lock_guard<std::mutex> guard(device->mutex);
-
- for (auto& pair : device->pipeline_descriptor_set_requirements) {
- vk_pipeline pipeline = device->pipelines.at(pair.first).lock();
- const uint64_t n = pair.second;
-
- VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
-
- if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
- // Enough descriptors are available
- continue;
- }
-
- uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
- uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
- uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-
- while (to_alloc > 0) {
- const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
- to_alloc -= alloc_count;
- pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-
- if (pool_idx >= pipeline->descriptor_pools.size()) {
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
- vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
- pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
- }
-
- std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
- for (uint32_t i = 0; i < alloc_count; i++) {
- layouts[i] = pipeline->dsl;
- }
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data());
- std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
- pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
-
- pool_idx++;
- }
- }
-}
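-
-// Worked example (illustrative, not part of the original file; assumes VK_DEVICE_DESCRIPTOR_POOL_SIZE == 32,
-// the real value is defined elsewhere in this file): with 40 sets already allocated,
-// descriptor_set_idx == 40 and a request of n == 10, to_alloc = 50 - 40 = 10,
-// pool_remaining = 32 - 40 % 32 = 24 and pool_idx = 40 / 32 = 1, so all 10 sets are taken
-// from pool 1 in a single allocateDescriptorSets call; a new pool is only created once
-// pool_idx reaches the end of descriptor_pools.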
-
-static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
- VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")");
- pipeline->descriptor_set_idx = 0;
-}
-
-static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
- VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
- std::lock_guard<std::mutex> guard(device->mutex);
-
- if (q.cmd_buffers.size() > q.cmd_buffer_idx) {
- // Reuse command buffer
- return q.cmd_buffers[q.cmd_buffer_idx++];
- }
-
- vk::CommandBufferAllocateInfo command_buffer_alloc_info(
- q.pool,
- vk::CommandBufferLevel::ePrimary,
- 1);
- const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
- auto buf = cmd_buffers.front();
-
- q.cmd_buffers.push_back(buf);
- q.cmd_buffer_idx++;
-
- return buf;
-}
-
-static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
- VK_LOG_DEBUG("ggml_vk_create_submission()");
- vk_submission s;
- s.buffer = ggml_vk_create_cmd_buffer(device, q);
- s.wait_semaphores = std::move(wait_semaphores);
- s.signal_semaphores = std::move(signal_semaphores);
- return s;
-}
-
-static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
- if (ctx->seqs.empty()) {
- if (fence) {
- ctx->q->queue.submit({}, fence);
- }
- return;
- }
- VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
-
- std::vector<std::vector<uint64_t>> tl_wait_vals;
- std::vector<std::vector<uint64_t>> tl_signal_vals;
- std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
- std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
- std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
- std::vector<vk::SubmitInfo> submit_infos;
- int idx = -1;
- std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
-
- size_t reserve = 0;
-
- for (const auto& sequence : ctx->seqs) {
- reserve += sequence.size();
- }
-
- // Pre-reserve vectors to prevent reallocation, which invalidates pointers
- tl_wait_semaphores.reserve(reserve);
- tl_wait_vals.reserve(reserve);
- tl_signal_semaphores.reserve(reserve);
- tl_signal_vals.reserve(reserve);
- tl_submit_infos.reserve(reserve);
- submit_infos.reserve(reserve);
- stage_flags.reserve(reserve);
-
- for (const auto& sequence : ctx->seqs) {
- for (const auto& submission : sequence) {
- stage_flags.push_back({});
- idx++;
- tl_wait_vals.push_back({});
- tl_wait_semaphores.push_back({});
- tl_signal_vals.push_back({});
- tl_signal_semaphores.push_back({});
- for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
- stage_flags[idx].push_back(ctx->q->stage_flags);
- tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
- tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
- }
- for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
- tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
- tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
- }
- tl_submit_infos.push_back({
- (uint32_t) submission.wait_semaphores.size(),
- tl_wait_vals[idx].data(),
- (uint32_t) submission.signal_semaphores.size(),
- tl_signal_vals[idx].data(),
- });
- tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
- tl_submit_infos[idx].pNext = nullptr;
- vk::SubmitInfo si{
- (uint32_t) submission.wait_semaphores.size(),
- tl_wait_semaphores[idx].data(),
- stage_flags[idx].data(),
- 1,
- &submission.buffer,
- (uint32_t) submission.signal_semaphores.size(),
- tl_signal_semaphores[idx].data(),
- };
- si.setPNext(&tl_submit_infos[idx]);
- submit_infos.push_back(si);
- }
- }
-
- ctx->q->queue.submit(submit_infos, fence);
-
- ctx->seqs.clear();
-}
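-
-// Note (descriptive, not part of the original file): each vk::SubmitInfo above chains a
-// vk::TimelineSemaphoreSubmitInfo through pNext, which is how timeline semaphore wait/signal
-// values are attached in Vulkan. The per-submission vectors are pre-reserved because the
-// SubmitInfo structs keep raw pointers into them until queue.submit() returns.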
-
-static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
- VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
- const uint32_t qfsize = queue_family_props.size();
-
- // Try with avoid preferences first
- for (uint32_t i = 0; i < qfsize; i++) {
- if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
- return i;
- }
- }
-
- // Fall back to only required
- for (size_t i = 0; i < qfsize; i++) {
- if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
- return i;
- }
- }
-
- // Fall back to reusing compute queue
- for (size_t i = 0; i < qfsize; i++) {
- if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
- return i;
- }
- }
-
- // Fall back to ignoring min_num_queues
- for (size_t i = 0; i < qfsize; i++) {
- if (queue_family_props[i].queueFlags & required) {
- return i;
- }
- }
-
- // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
- // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
- if (compute_index >= 0) {
- return compute_index;
- }
-
- std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
-
- for(auto &q_family : queue_family_props) {
- std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
- }
- abort();
-}
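-
-// Illustrative scenario (assumption, not from the original file): on a discrete GPU exposing
-// family 0 = graphics|compute|transfer and family 1 = transfer-only, a request for a transfer
-// queue that avoids graphics|compute would typically select family 1 in the first loop, while
-// on a device with a single all-purpose family the later fallbacks (or compute_index, per the
-// spec note above) end up reusing that one index.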
-
-static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
- VK_LOG_DEBUG("ggml_vk_create_queue()");
- std::lock_guard<std::mutex> guard(device->mutex);
-
- q.queue_family_index = queue_family_index;
- q.transfer_only = transfer_only;
-
- vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
- q.pool = device->device.createCommandPool(command_pool_create_info_compute);
-
- q.cmd_buffer_idx = 0;
-
- q.queue = device->device.getQueue(queue_family_index, queue_index);
-
- q.stage_flags = stage_flags;
-}
-
-static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
- vk_context result = std::make_shared<vk_context_struct>();
- VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
- ctx->gc.contexts.emplace_back(result);
- result->q = &q;
- return result;
-}
-
-static vk_context ggml_vk_create_temporary_context(vk_queue& q) {
- vk_context result = std::make_shared<vk_context_struct>();
- VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
- result->q = &q;
- return result;
-}
-
-static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
- VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
- vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
- vk::SemaphoreCreateInfo ci{};
- ci.setPNext(&tci);
- vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
- ctx->gc.semaphores.push_back({ semaphore, 0 });
- return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
-}
-
-static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
- VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
- if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
- vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
- vk::SemaphoreCreateInfo ci{};
- ci.setPNext(&tci);
- vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
- ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
- }
- return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
-}
-
-static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
- if (ctx->event_idx >= ctx->gc.events.size()) {
- ctx->gc.events.push_back(ctx->device->device.createEvent({}));
- }
- return ctx->gc.events[ctx->event_idx++];
-}
-
-static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
- VK_LOG_DEBUG("ggml_vk_queue_cleanup()");
- std::lock_guard<std::mutex> guard(device->mutex);
-
- // Requires command buffers to be done
- device->device.resetCommandPool(q.pool);
- q.cmd_buffer_idx = 0;
-}
-
-static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
- for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
- vk::MemoryType memory_type = mem_props->memoryTypes[i];
- if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
- (flags & memory_type.propertyFlags) == flags &&
- mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
- return i;
- }
- }
- return UINT32_MAX;
-}
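-
-// Usage sketch (illustrative, not part of the original file): callers probe for a memory type in
-// preference order, e.g.
-//     uint32_t idx = find_properties(&mem_props, &mem_req, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible);
-//     if (idx == UINT32_MAX) { idx = find_properties(&mem_props, &mem_req, vk::MemoryPropertyFlagBits::eDeviceLocal); }
-// which mirrors the required-flags/fallback-flags pattern implemented by ggml_vk_create_buffer below.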
-
-static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
- VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
- if (size > device->max_memory_allocation_size) {
- throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
- }
-
- std::lock_guard<std::mutex> guard(device->mutex);
-
- vk_buffer buf = std::make_shared<vk_buffer_struct>();
-
- if (size == 0) {
- buf->size = 0;
- return buf;
- }
-
- vk::BufferCreateInfo buffer_create_info{
- vk::BufferCreateFlags(),
- size,
- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
- vk::SharingMode::eExclusive,
- 0,
- nullptr,
- };
-
- buf->buffer = device->device.createBuffer(buffer_create_info);
-
- vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
-
- vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
-
- uint32_t memory_type_index = UINT32_MAX;
-
- memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
- buf->memory_property_flags = req_flags;
-
- if (memory_type_index == UINT32_MAX && fallback_flags) {
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
- }
-
- if (memory_type_index == UINT32_MAX) {
- device->device.destroyBuffer(buf->buffer);
- throw vk::OutOfDeviceMemoryError("No suitable memory type found");
- }
-
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- } catch (const vk::SystemError& e) {
- if (buf->memory_property_flags != fallback_flags) {
- // Try again with fallback flags
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
-
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- }
- catch (const vk::SystemError& e) {
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- } else {
- // Out of Host/Device memory, clean up buffer
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- }
- buf->ptr = nullptr;
-
- if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
- }
-
- device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
-
- buf->device = device;
- buf->size = size;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
- device->memory_logger->log_allocation(buf, size);
-#endif
-
- return buf;
-}
-
-static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
- try {
- return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
- } catch (const vk::SystemError& e) {
- std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
- std::cerr << "ggml_vulkan: " << e.what() << std::endl;
- throw e;
- }
-}
-
-static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
- vk_buffer buf;
- try {
- if (device->uma) {
- // Fall back to host memory type
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
- } else {
- // use rebar if available, otherwise fallback to device only visible memory
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
- }
- } catch (const vk::SystemError& e) {
- std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
- std::cerr << "ggml_vulkan: " << e.what() << std::endl;
- throw e;
- }
-
- return buf;
-}
-
-static void ggml_vk_destroy_buffer(vk_buffer& buf) {
- if (buf == nullptr) {
- return;
- }
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
- if (buf->device != nullptr) {
- buf->device->memory_logger->log_deallocation(buf);
- }
-#endif
-
- buf.reset();
-}
-
-static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
- return { buf, 0, VK_WHOLE_SIZE };
-}
-
-static void ggml_vk_sync_buffers(vk_context& ctx) {
- VK_LOG_DEBUG("ggml_vk_sync_buffers()");
-
- const bool transfer_queue = ctx->q->transfer_only;
-
- ctx->s->buffer.pipelineBarrier(
- ctx->q->stage_flags,
- ctx->q->stage_flags,
- {},
- { {
- { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
- { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
- } },
- {},
- {}
- );
-}
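-
-// Note (descriptive, not part of the original file): this records a single global vk::MemoryBarrier
-// covering shader read/write and transfer read/write (narrowed to transfer access on transfer-only
-// queues) between the queue's own stage flags, i.e. a coarse "everything before vs. everything after"
-// synchronization rather than per-buffer barriers.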
-
-static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
- VK_LOG_DEBUG("ggml_vk_wait_events()");
- if (events.empty()) {
- return;
- }
-
- ctx->s->buffer.waitEvents(
- events,
- ctx->q->stage_flags,
- ctx->q->stage_flags,
- {},
- {},
- {}
- );
-}
-
-static void ggml_vk_load_shaders(vk_device& device) {
- VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
-
- std::cerr << "ggml_vulkan: Compiling shaders";
-
- // mulmat
- std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
-
- std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
-
- std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
- std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
- std::array<uint32_t, 3> s_wg_denoms = { 32, 32, 1 };
-
- uint32_t l_align = 128;
- uint32_t m_align = 64;
- uint32_t s_align = 32;
-
- device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
-
- device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
-
- std::vector<std::future<void>> compiles;
- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
- {
- // wait until fewer than N compiles are in progress
- uint32_t N = std::max(1u, std::thread::hardware_concurrency());
- std::unique_lock<std::mutex> guard(compile_count_mutex);
- while (compile_count >= N) {
- compile_count_cond.wait(guard);
- }
- compile_count++;
- }
- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
- };
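-
-// Note (illustrative, not part of the original file): with std::thread::hardware_concurrency()
-// returning e.g. 8, at most 8 pipeline compiles are in flight at once; the lambda blocks on
-// compile_count_cond until ggml_vk_create_pipeline_func decrements compile_count and calls
-// notify_all(), then hands the next compile off via std::async and keeps the future in `compiles`.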
-
- if (device->fp16) {
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- } else {
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- }
-
- // mul mat vec
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
- // dequant shaders
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-
- // get_rows
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
-
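- // Wait for all pending pipeline compilations to finish before returning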
- for (auto &c : compiles) {
- c.wait();
- }
- std::cerr << "Done!" << std::endl;
-}
-
-static vk_device ggml_vk_get_device(size_t idx) {
- VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
-
- if (vk_instance.devices[idx] == nullptr) {
- VK_LOG_DEBUG("Initializing new vk_device");
- vk_device device = std::make_shared<vk_device_struct>();
- vk_instance.devices[idx] = device;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
- device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
-#endif
-#ifdef GGML_VULKAN_PERF
- device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
-#endif
-
- size_t dev_num = vk_instance.device_indices[idx];
-
- std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
-
- if (dev_num >= physical_devices.size()) {
- std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
- throw std::runtime_error("Device not found");
- }
-
- device->physical_device = physical_devices[dev_num];
- const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
-
- bool maintenance4_support = false;
-
- // Check if maintenance4 is supported
- for (const auto& properties : ext_props) {
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
- maintenance4_support = true;
- }
- }
-
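- // Chain the maintenance3, subgroup and (if supported) maintenance4 property structs via pNext so a single getProperties2() call fills them all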
- vk::PhysicalDeviceProperties2 props2;
- vk::PhysicalDeviceMaintenance3Properties props3;
- vk::PhysicalDeviceMaintenance4Properties props4;
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
- props2.pNext = &props3;
- props3.pNext = &subgroup_props;
- if (maintenance4_support) {
- subgroup_props.pNext = &props4;
- }
- device->physical_device.getProperties2(&props2);
- device->properties = props2.properties;
-
- const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
-
- if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
- device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
- } else if (maintenance4_support) {
- device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
- } else {
- device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
- }
-
- device->vendor_id = device->properties.vendorID;
- device->subgroup_size = subgroup_props.subgroupSize;
- device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
- bool fp16_storage = false;
- bool fp16_compute = false;
-
- for (const auto& properties : ext_props) {
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
- fp16_storage = true;
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
- fp16_compute = true;
- }
- }
-
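- // fp16 is only enabled when both 16-bit storage and float16/int8 compute are available, and can be force-disabled via GGML_VK_DISABLE_F16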
- const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
- const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
-
- device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
-
- std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
-
- // Try to find a non-graphics compute queue and transfer-focused queues
- const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
- const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
-
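- // Two priorities are provided because up to two queues (compute and transfer) may be created below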
- const float priorities[] = { 1.0f, 1.0f };
- device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
-
- std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
- if (compute_queue_family_index != transfer_queue_family_index) {
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
- } else if (!device->single_queue) {
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
- } else {
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
- }
- vk::DeviceCreateInfo device_create_info;
- std::vector<const char *> device_extensions;
- vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
-
- VkPhysicalDeviceFeatures2 device_features2;
- device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
- device_features2.pNext = nullptr;
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
-
- VkPhysicalDeviceVulkan11Features vk11_features;
- vk11_features.pNext = nullptr;
- vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
- device_features2.pNext = &vk11_features;
-
- VkPhysicalDeviceVulkan12Features vk12_features;
- vk12_features.pNext = nullptr;
- vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
- vk11_features.pNext = &vk12_features;
-
- vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
-
- device->fp16 = device->fp16 && vk12_features.shaderFloat16;
-
- if (!vk11_features.storageBuffer16BitAccess) {
- std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
- throw std::runtime_error("Unsupported device");
- }
-
- device_extensions.push_back("VK_KHR_16bit_storage");
-
-#ifdef GGML_VULKAN_VALIDATE
- device_extensions.push_back("VK_KHR_shader_non_semantic_info");
-#endif
-
- if (device->fp16) {
- device_extensions.push_back("VK_KHR_shader_float16_int8");
- }
- device->name = GGML_VK_NAME + std::to_string(idx);
-
- device_create_info = {
- vk::DeviceCreateFlags(),
- device_queue_create_infos,
- {},
- device_extensions
- };
- device_create_info.setPNext(&device_features2);
- device->device = device->physical_device.createDevice(device_create_info);
-
- // Queues
- ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
-
- // Shaders
- ggml_vk_load_shaders(device);
-
- if (!device->single_queue) {
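- // If compute and transfer share a queue family, the transfer queue is the second queue (index 1) of that family; otherwise it is the first queue of its own family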
- const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
- ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
- } else {
- // TODO: Use pointer or reference to avoid copy
- device->transfer_queue = device->compute_queue;
- }
-
- device->buffer_type = {
- /* .iface = */ ggml_backend_vk_buffer_type_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
- /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
- };
-
- device->fence = device->device.createFence({});
-
- device->idx = idx;
-
- return device;
- }
-
- return vk_instance.devices[idx];
-}
-
-
-static void ggml_vk_print_gpu_info(size_t idx) {
- GGML_ASSERT(idx < vk_instance.device_indices.size());
- size_t dev_num = vk_instance.device_indices[idx];
- VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
- GGML_ASSERT(vk_instance_initialized);
-
- std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
- if (dev_num >= devices.size()) {
- std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
- throw std::runtime_error("Device not found");
- }
-
- vk::PhysicalDevice physical_device = devices[dev_num];
- std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
-
- vk::PhysicalDeviceProperties2 props2;
- vk::PhysicalDeviceMaintenance3Properties props3;
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
- vk::PhysicalDeviceDriverProperties driver_props;
- props2.pNext = &props3;
- props3.pNext = &subgroup_props;
- subgroup_props.pNext = &driver_props;
- physical_device.getProperties2(&props2);
-
- const size_t subgroup_size = subgroup_props.subgroupSize;
- const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
- bool fp16_storage = false;
- bool fp16_compute = false;
-
- for (const auto& properties : ext_props) {
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
- fp16_storage = true;
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
- fp16_compute = true;
- }
- }
-
- const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
- bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
-
- bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
-
- vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
-
- VkPhysicalDeviceFeatures2 device_features2;
- device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
- device_features2.pNext = nullptr;
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
-
- VkPhysicalDeviceVulkan11Features vk11_features;
- vk11_features.pNext = nullptr;
- vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
- device_features2.pNext = &vk11_features;
-
- VkPhysicalDeviceVulkan12Features vk12_features;
- vk12_features.pNext = nullptr;
- vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
- vk11_features.pNext = &vk12_features;
-
- vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
-
- fp16 = fp16 && vk12_features.shaderFloat16;
-
- std::string device_name = props2.properties.deviceName.data();
- GGML_LOG_DEBUG("ggml_vulkan: %d = %s (%s) | uma: %d | fp16: %d | warp size: %d\n",
- idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
-
- if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
- std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
- }
-}
-
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
-static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
-
-void ggml_vk_instance_init() {
- if (vk_instance_initialized) {
- return;
- }
- VK_LOG_DEBUG("ggml_vk_instance_init()");
-
- vk_instance_initialized = true;
-
- vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
-
- const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
- const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
-#ifdef __APPLE__
- const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
-#endif
-
- std::vector<const char*> layers;
-
- if (validation_ext) {
- layers.push_back("VK_LAYER_KHRONOS_validation");
- }
- std::vector<const char*> extensions;
- if (validation_ext) {
- extensions.push_back("VK_EXT_validation_features");
- }
-#ifdef __APPLE__
- if (portability_enumeration_ext) {
- extensions.push_back("VK_KHR_portability_enumeration");
- }
-#endif
- vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
-#ifdef __APPLE__
- if (portability_enumeration_ext) {
- instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
- }
-#endif
-
- std::vector<vk::ValidationFeatureEnableEXT> features_enable;
- vk::ValidationFeaturesEXT validation_features;
-
- if (validation_ext) {
- features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
- validation_features = {
- features_enable,
- {},
- };
- validation_features.setPNext(nullptr);
- instance_create_info.setPNext(&validation_features);
- GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
- }
- vk_instance.instance = vk::createInstance(instance_create_info);
-
- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
-
- // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
- char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
- if (devices_env != nullptr) {
- std::string devices(devices_env);
- std::replace(devices.begin(), devices.end(), ',', ' ');
-
- std::stringstream ss(devices);
- size_t tmp;
- while (ss >> tmp) {
- if (tmp >= num_available_devices) {
- std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
- throw std::runtime_error("Invalid Vulkan device index");
- }
- vk_instance.device_indices.push_back(tmp);
- }
- } else {
- std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
- // Make sure at least one device exists
- if (devices.empty()) {
- std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
- GGML_ABORT("fatal error");
- }
-
- // Default to using all dedicated GPUs
- for (size_t i = 0; i < devices.size(); i++) {
- vk::PhysicalDeviceProperties2 new_props;
- vk::PhysicalDeviceDriverProperties new_driver;
- vk::PhysicalDeviceIDProperties new_id;
- new_props.pNext = &new_driver;
- new_driver.pNext = &new_id;
- devices[i].getProperties2(&new_props);
-
- if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
- // Check if there are two physical devices corresponding to the same GPU
- auto old_device = std::find_if(
- vk_instance.device_indices.begin(),
- vk_instance.device_indices.end(),
- [&devices, &new_id](const size_t k){
- vk::PhysicalDeviceProperties2 old_props;
- vk::PhysicalDeviceIDProperties old_id;
- old_props.pNext = &old_id;
- devices[k].getProperties2(&old_props);
- return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
- }
- );
- if (old_device == vk_instance.device_indices.end()) {
- vk_instance.device_indices.push_back(i);
- } else {
- // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
- // This can cause errors when splitting layers across the devices, so only one of them is kept
- VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
-
- vk::PhysicalDeviceProperties2 old_props;
- vk::PhysicalDeviceDriverProperties old_driver;
- old_props.pNext = &old_driver;
- devices[*old_device].getProperties2(&old_props);
-
- std::map<vk::DriverId, int> driver_priorities {};
- int old_priority = std::numeric_limits<int>::max();
- int new_priority = std::numeric_limits<int>::max();
-
- // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver ids
- // Smaller number -> higher priority
- switch (old_props.properties.vendorID) {
- case VK_VENDOR_ID_AMD:
- driver_priorities[vk::DriverId::eMesaRadv] = 1;
- driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
- driver_priorities[vk::DriverId::eAmdProprietary] = 3;
- break;
- case VK_VENDOR_ID_INTEL:
- driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
- driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
- break;
- case VK_VENDOR_ID_NVIDIA:
- driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
-#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
- driver_priorities[vk::DriverId::eMesaNvk] = 2;
-#endif
- break;
- }
-
- if (driver_priorities.count(old_driver.driverID)) {
- old_priority = driver_priorities[old_driver.driverID];
- }
- if (driver_priorities.count(new_driver.driverID)) {
- new_priority = driver_priorities[new_driver.driverID];
- }
-
- if (new_priority < old_priority) {
- auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
- vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
- vk_instance.device_indices.push_back(i);
-
- VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
- }
- else {
- VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl);
- }
- }
- }
- }
-
- // If no dedicated GPUs found, fall back to GPU 0
- if (vk_instance.device_indices.empty()) {
- vk_instance.device_indices.push_back(0);
- }
- }
- GGML_LOG_DEBUG("ggml_vulkan: Found %d Vulkan devices:\n", vk_instance.device_indices.size());
-
-
- for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
- ggml_vk_print_gpu_info(i);
- }
-}
-
-static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
- VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
- ggml_vk_instance_init();
- GGML_ASSERT(idx < vk_instance.device_indices.size());
-
- ctx->name = GGML_VK_NAME + std::to_string(idx);
-
- ctx->device = ggml_vk_get_device(idx);
-
- ctx->semaphore_idx = 0;
- ctx->event_idx = 0;
-
- ctx->prealloc_size_x = 0;
- ctx->prealloc_size_y = 0;
- ctx->prealloc_size_split_k = 0;
-
- ctx->fence = ctx->device->device.createFence({});
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
- const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
- vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
- const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
- vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
-#endif
-}
-
-static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
- VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
- switch (type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return nullptr;
- }
-
- return ctx->device->pipeline_dequant[type];
-}
-
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
- VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return ctx->device->pipeline_matmul_f32;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
- return ctx->device->pipeline_matmul_f32_f16;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
- return ctx->device->pipeline_matmul_f16_f32;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return ctx->device->pipeline_matmul_f16;
- }
-
- if (src1_type != GGML_TYPE_F32) {
- return nullptr;
- }
-
- switch (src0_type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return nullptr;
- }
-
- return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
-}
-
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
- VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
- GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
-
- switch (a_type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return nullptr;
- }
-
- return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
-}
-
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
- VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return ctx->device->pipeline_matmul_id_f32;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
- return ctx->device->pipeline_matmul_id_f16_f32;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return ctx->device->pipeline_matmul_id_f16;
- }
-
- GGML_ASSERT(src1_type == GGML_TYPE_F32);
-
- switch (src0_type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return nullptr;
- }
-
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
-}
-
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
- VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
- GGML_ASSERT(b_type == GGML_TYPE_F32);
-
- switch (a_type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return nullptr;
- }
-
- return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
-}
-
-static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
- VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
- VK_LOG_MEMORY("ggml_vk_pool_malloc");
-
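- // Reuse the smallest pooled buffer that is large enough; if none fits, free the largest pooled buffer and allocate a new one of the requested size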
- int best_i = -1;
- size_t best_size = std::numeric_limits<size_t>::max(); // smallest unused buffer that fits our needs
- int worst_i = -1;
- size_t worst_size = 0; // largest unused buffer seen so far
- for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
- vk_buffer &b = ctx->buffer_pool[i];
- if (b != nullptr && b->size >= size && b->size < best_size) {
- best_i = i;
- best_size = b->size;
- }
- if (b != nullptr && b->size > worst_size) {
- worst_i = i;
- worst_size = b->size;
- }
- }
- if (best_i != -1) {
- // found the smallest buffer that fits our needs
- vk_buffer b = ctx->buffer_pool[best_i];
- ctx->buffer_pool[best_i].reset();
- return b;
- }
- if (worst_i != -1) {
- // no pooled buffer fits our needs, so destroy the largest one to free memory before allocating a new buffer
- vk_buffer& b = ctx->buffer_pool[worst_i];
- ggml_vk_destroy_buffer(b);
- }
-
- return ggml_vk_create_buffer_device(ctx->device, size);
-}
-
-static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
- VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
- for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
- vk_buffer& b = ctx->buffer_pool[i];
- if (b == nullptr) {
- b = buffer;
- return;
- }
- }
- std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
- ggml_vk_destroy_buffer(buffer);
-}
-
-// Returns a temporary buffer that is only valid for short-term use, since it may be reused by later calls
-static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
- // Try to find existing temp buffer with enough capacity
- for (auto& buffer : ctx->gc.temp_buffers) {
- if (buffer->size >= size) {
- return buffer;
- }
- }
-
- VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
-
- // Otherwise create new buffer
- vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
- ctx->gc.temp_buffers.push_back(buf);
-
- return buf;
-}
-
-static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
- VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
- vk_buffer buf = ggml_vk_create_buffer(device, size,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-
- if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
- fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
- size/1024.0/1024.0);
- device->device.freeMemory(buf->device_memory);
- device->device.destroyBuffer(buf->buffer);
- return nullptr;
- }
-
- device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
-
- return buf->ptr;
-}
-
-static void ggml_vk_host_free(vk_device& device, void* ptr) {
- if (ptr == nullptr) {
- return;
- }
- VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
- vk_buffer buf;
- size_t index;
- for (size_t i = 0; i < device->pinned_memory.size(); i++) {
- const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
- const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
- if (ptr >= addr && ptr < endr) {
- buf = std::get<2>(device->pinned_memory[i]);
- index = i;
- break;
- }
- }
- if (buf == nullptr) {
- fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
- return;
- }
-
- ggml_vk_destroy_buffer(buf);
-
- device->pinned_memory.erase(device->pinned_memory.begin() + index);
-}
-
-static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
- buf = nullptr;
- buf_offset = 0;
- for (size_t i = 0; i < device->pinned_memory.size(); i++) {
- const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
- const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
- if (ptr >= addr && ptr < endr) {
- buf = std::get<2>(device->pinned_memory[i]);
- buf_offset = ((const uint8_t *)ptr) - addr;
- break;
- }
- }
-}
-
-static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) {
- vk_submission s;
- s.buffer = ggml_vk_create_cmd_buffer(device, q);
- if (one_time) {
- s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
- } else {
- s.buffer.begin({ vk::CommandBufferUsageFlags{} });
- }
-
- return s;
-}
-
-
-
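-// Writes the storage-buffer descriptors, uploads the push constants, binds the pipeline and records a
-// compute dispatch; the workgroup counts are derived by dividing `elements` by the pipeline's wg_denoms.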
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
- const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
- const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
- const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
- VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
- for (auto& buffer : descriptor_buffer_infos) {
- std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
- }
- std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
- GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
- GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count);
-
- vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
- vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
- ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
-
- subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
- subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
- subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
- pipeline->layout,
- 0,
- { descriptor_set },
- {});
- subctx->s->buffer.dispatch(wg0, wg1, wg2);
-}
-
-static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
- s.buffer.end();
-
- s.wait_semaphores = std::move(wait_semaphores);
- s.signal_semaphores = std::move(signal_semaphores);
-}
-
-static void ggml_vk_ctx_end(vk_context& ctx) {
- VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
- if (ctx->s == nullptr) {
- return;
- }
-
- ctx->s->buffer.end();
- ctx->s = nullptr;
-}
-
-static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
- VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
- if (subctx->s != nullptr) {
- ggml_vk_ctx_end(subctx);
- }
-
- subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
- subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
-}
-
-static size_t ggml_vk_align_size(size_t width, size_t align) {
- VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
- return CEIL_DIV(width, align) * align;
-}
-
-static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
- if (memcpys == nullptr) {
- memcpy(dst, src, size);
- } else {
- memcpys->emplace_back(dst, src, size);
- }
-}
-
-static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
- if (device->sync_staging == nullptr || device->sync_staging->size < size) {
- VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
- ggml_vk_destroy_buffer(device->sync_staging);
- device->sync_staging = ggml_vk_create_buffer_check(device, size,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
- }
-}
-
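-// Writes a non-contiguous tensor into a device buffer. If the source is in pinned host memory it is used
-// directly as the copy source; otherwise the data is gathered into the synchronous staging buffer through
-// deferred memcpys before the recorded copy executes.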
-static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
- VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
- GGML_ASSERT(!ggml_is_contiguous(tensor));
- // Buffer is already mapped
- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
- GGML_ABORT("fatal error");
- }
- // Check if src is pinned memory
- vk_buffer buf;
- size_t buf_offset;
- ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
-
- const uint64_t ne0 = tensor->ne[0];
- const uint64_t ne1 = tensor->ne[1];
- const uint64_t ne2 = tensor->ne[2];
- const uint64_t ne3 = tensor->ne[3];
- const uint64_t nb0 = tensor->nb[0];
- const uint64_t nb1 = tensor->nb[1];
- const uint64_t nb2 = tensor->nb[2];
- const uint64_t nb3 = tensor->nb[3];
- const ggml_type type = tensor->type;
- const uint64_t ts = ggml_type_size(type);
- const uint64_t bs = ggml_blck_size(type);
-
- const uint64_t dstnb0 = ts;
- const uint64_t dstnb1 = dstnb0*(ne0/bs);
- const uint64_t dstnb2 = dstnb1*ne1;
- const uint64_t dstnb3 = dstnb2*ne2;
-
- const uint64_t ne = ggml_nelements(tensor);
-
- if (buf != nullptr) {
- // Memory is pinned, use as staging buffer
- std::vector<vk::BufferCopy> slices;
-
- for (uint64_t i3 = 0; i3 < ne3; i3++) {
- for (uint64_t i2 = 0; i2 < ne2; i2++) {
- // Find longest contiguous slice
- if (ne1*nb1 == dstnb2) {
- slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
- } else {
- for (uint64_t i1 = 0; i1 < ne1; i1++) {
- if (ne0*nb0/bs == dstnb1) {
- slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
- } else {
- const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
- const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
- for (uint64_t i0 = 0; i0 < ne0; i0++) {
- slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
- }
- }
- }
- }
- }
- }
-
- ggml_vk_sync_buffers(subctx);
- subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
- return;
- }
-
- if (!sync_staging) {
- GGML_ABORT("Asynchronous write to non-pinned memory not supported");
- }
-
- // Staging buffer required
- vk_buffer& staging = ctx->device->sync_staging;
- const uint64_t copy_size = ts*ne/bs;
- ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
- VkBufferCopy buf_copy{ 0, offset, copy_size };
-
- ggml_vk_sync_buffers(subctx);
- vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
-
- for (uint64_t i3 = 0; i3 < ne3; i3++) {
- for (uint64_t i2 = 0; i2 < ne2; i2++) {
- // Find longest contiguous slice
- if (ne1*nb1 == dstnb2) {
- deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
- } else {
- for (uint64_t i1 = 0; i1 < ne1; i1++) {
- if (ne0*nb0/bs == dstnb1) {
- deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
- } else {
- const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
- const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
- for (uint64_t i0 = 0; i0 < ne0; i0++) {
- deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
- }
- }
- }
- }
- }
- }
-}
-
-static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
- VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
- // Buffer is already mapped
- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
- GGML_ABORT("fatal error");
- }
- // Check if src is pinned memory
- vk_buffer buf = nullptr;
- size_t buf_offset;
- ggml_vk_host_get(dst->device, src, buf, buf_offset);
-
- if (buf != nullptr) {
- // Memory is pinned, use as staging buffer
- std::vector<vk::BufferCopy> slices(1);
- if (width == spitch) {
- // Only do single write if stride is equal
- slices[0].srcOffset = buf_offset;
- slices[0].dstOffset = offset;
- slices[0].size = width * height;
- } else {
- slices.resize(height);
- for (size_t i = 0; i < height; i++) {
- slices[i].srcOffset = buf_offset + i * spitch;
- slices[i].dstOffset = offset + i * width;
- slices[i].size = width;
- }
- }
-
- ggml_vk_sync_buffers(subctx);
- subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
- return;
- }
- VK_LOG_DEBUG("STAGING");
-
- if (!sync_staging) {
- GGML_ABORT("Asynchronous write to non-pinned memory not supported");
- }
-
- // Staging buffer required
- const size_t copy_size = width*height;
- ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
-
- vk_buffer& staging_buffer = dst->device->sync_staging;
-
- VkBufferCopy buf_copy = {
- 0,
- offset,
- copy_size};
-
- ggml_vk_sync_buffers(subctx);
- vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
-
- if (width == spitch) {
- deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
- } else {
- for (size_t i = 0; i < height; i++) {
- deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
- }
- }
-}
-
-static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
- VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
- return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
-}
-
-static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
- VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
- // Buffer is already mapped
- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
-
- for (size_t i = 0; i < height; i++) {
- memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
- }
- } else {
- vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
- ggml_vk_ctx_begin(dst->device, subctx);
- ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
- ggml_vk_ctx_end(subctx);
-
- for (auto& cpy : subctx->in_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
-
- ggml_vk_submit(subctx, dst->device->fence);
- VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
- dst->device->device.resetFences({ dst->device->fence });
- }
-}
-
-static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
- VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
- ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
-}
-
-static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
- VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
- GGML_ASSERT(width > 0);
- GGML_ASSERT(height > 0);
- GGML_ASSERT(src != nullptr);
-
- // Check if dst is pinned memory
- vk_buffer buf = nullptr;
- size_t buf_offset;
- ggml_vk_host_get(src->device, dst, buf, buf_offset);
-
- std::vector<vk::BufferCopy> slices(1);
- if (width == spitch && width == dpitch) {
- // Only do single write if stride is equal
- slices[0].srcOffset = offset;
- slices[0].dstOffset = buf_offset;
- slices[0].size = width * height;
- } else {
- slices.resize(height);
- for (size_t i = 0; i < height; i++) {
- slices[i].srcOffset = offset + i * spitch;
- slices[i].dstOffset = buf_offset + i * dpitch;
- slices[i].size = width;
- }
- }
-
- if (buf != nullptr) {
- // Memory is pinned, use as staging buffer
- ggml_vk_sync_buffers(subctx);
- subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
-
- return;
- }
- VK_LOG_DEBUG("STAGING");
-
- if (!sync_staging) {
- GGML_ABORT("Asynchronous read from non-pinned memory not supported");
- }
-
- // Fall back to staging buffer
- const size_t copy_size = dpitch * height;
- ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
-
- vk_buffer& staging_buffer = src->device->sync_staging;
-
- ggml_vk_sync_buffers(subctx);
- subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
-
- deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
-}
-
-static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
- return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
-}
-
-static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
- VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
-
- // Even on non-UMA devices the memory may be host-accessible through ReBAR. While writing
- // through PCIe is sufficiently fast, reading data back over PCIe is slower than going through
- // the device-to-host copy path, so only read directly from host-visible memory on UMA devices.
- if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
- GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
-
- memcpy(dst, (uint8_t *) src->ptr + offset, size);
- } else {
- vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
- ggml_vk_ctx_begin(src->device, subctx);
- ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
- ggml_vk_ctx_end(subctx);
-
- ggml_vk_submit(subctx, src->device->fence);
- VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
- src->device->device.resetFences({ src->device->fence });
-
- for (auto& cpy : subctx->out_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
- }
-}
-
-static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
- VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
- // Make sure both buffers are on same device
- GGML_ASSERT(src->device == dst->device);
-
- VkBufferCopy bc{ src_offset, dst_offset, size };
-
- vkCmdCopyBuffer(ctx->s->buffer, src->buffer, dst->buffer, 1, &bc);
-}
-
-static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
- if (src->device == dst->device) {
- VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
- // Copy within the device
- vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
- ggml_vk_ctx_begin(src->device, subctx);
- ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
- ggml_vk_ctx_end(subctx);
- ggml_vk_submit(subctx, src->device->fence);
- VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
- src->device->device.resetFences({ src->device->fence });
- } else {
- VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
- // Copy device to device
- ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
-
- // Copy to src staging buffer
- ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
- // Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
- }
-}
-
-static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
- VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
-
- vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
- ggml_vk_ctx_begin(dst->device, subctx);
- subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
- ggml_vk_ctx_end(subctx);
-
- ggml_vk_submit(subctx, dst->device->fence);
- VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
- dst->device->device.resetFences({ dst->device->fence });
-}
-
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
- VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
- // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
- // return 4;
- // }
-
- return 1;
-
- GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
- if (m <= 32 || n <= 32) {
- return aligned ? mmp->a_s : mmp->s;
- }
- return aligned ? mmp->a_m : mmp->m;
-
- GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
- return aligned ? mmp->a_m : mmp->m;
-
- GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
- return aligned ? mmp->a_s : mmp->s;
-
- GGML_UNUSED(ctx);
-}
-
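-// Chooses between the small, medium and large matmul shader variants (and their aligned versions)
-// based on the matrix dimensions, with vendor-specific heuristics for AMD, Apple and Intel GPUs.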
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
- switch (ctx->device->vendor_id) {
- case VK_VENDOR_ID_AMD:
- return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
- case VK_VENDOR_ID_APPLE:
- return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
- case VK_VENDOR_ID_INTEL:
- return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
- default:
- break;
- }
-
- if (m <= 32 || n <= 32) {
- return aligned ? mmp->a_s : mmp->s;
- }
- if (m <= 64 || n <= 64) {
- return aligned ? mmp->a_m : mmp->m;
- }
- return aligned ? mmp->a_l : mmp->l;
-}
-
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
-}
-
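-// Records a matrix multiplication. With split_k == 1 a single dispatch writes the result directly;
-// otherwise split_k partial results over the K dimension are written to split_k_buffer and then summed
-// into d by the matmul_split_k_reduce pipeline.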
-static void ggml_vk_matmul(
- ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
- vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
- uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
- uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
- uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
- VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
- ggml_vk_sync_buffers(subctx);
- if (split_k == 1) {
- const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
- return;
- }
-
- GGML_ASSERT(batch_stride_d == m * n);
-
- const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
- // Make sure enough workgroups get assigned for split k to work
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
- ggml_vk_sync_buffers(subctx);
- const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
-}
-
-static void ggml_vk_matmul_id(
- ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
- vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
- uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
- uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
- uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
- VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
- "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
- "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
- "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
- ggml_vk_sync_buffers(subctx);
- const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
- nei0, nei1, nbi1, ne11 };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
-}
-
-static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
- return
- tensor->nb[0] == ggml_type_size(tensor->type) &&
- tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
- tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
-}
-
-static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
-
- // Choose "contiguous copy" shader if src/dst are contiguous
- bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
-
- if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
- if (contig) {
- return ctx->device->pipeline_contig_cpy_f32_f32;
- } else {
- return ctx->device->pipeline_cpy_f32_f32;
- }
- }
- if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
- if (contig) {
- return ctx->device->pipeline_contig_cpy_f32_f16;
- } else {
- return ctx->device->pipeline_cpy_f32_f16;
- }
- }
- if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
- if (contig) {
- return ctx->device->pipeline_contig_cpy_f16_f16;
- } else {
- return ctx->device->pipeline_cpy_f16_f16;
- }
- }
-
- std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
- GGML_ABORT("fatal error");
-}
-
-static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
- VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
- std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
- const int tensor_type_size = ggml_type_size(tensor->type);
-
- const uint32_t ne = ggml_nelements(tensor);
- std::array<uint32_t, 3> elements;
-
- if (ne > 262144) {
- elements = { 512, 512, CEIL_DIV(ne, 262144) };
- } else if (ne > 512) {
- elements = { 512, CEIL_DIV(ne, 512), 1 };
- } else {
- elements = { ne, 1, 1 };
- }
-
- const vk_op_unary_push_constants pc = {
- (uint32_t)ne,
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
- 0,
- 0.0f, 0.0f,
- };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
-}
-
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
- GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- const uint64_t ne03 = src0->ne[3];
-
- const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- const uint64_t ne13 = src1->ne[3];
-
- const uint64_t ne20 = dst->ne[0];
- const uint64_t ne21 = dst->ne[1];
-
- const uint64_t r2 = ne12 / ne02;
- const uint64_t r3 = ne13 / ne03;
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
- vk_buffer d_Qx;
- size_t qx_buf_offset = 0;
- vk_buffer d_Qy;
- size_t qy_buf_offset = 0;
-
- bool src0_uma = false;
- bool src1_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- src0_uma = d_Qx != nullptr;
- src1_uma = d_Qy != nullptr;
- }
-
- const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
- const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
-
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
-
- const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
-
- if (qx_needs_dequant) {
- // Fall back to dequant + f16 mulmat
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
- }
-
- // Not implemented
- GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
- const int x_ne = ne01 * ne00;
- const int y_ne = ne11 * ne10;
- const int d_ne = ne11 * ne01;
-
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
- const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
-
- const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
-
- const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- vk_pipeline to_fp16_vk_0 = nullptr;
- vk_pipeline to_fp16_vk_1 = nullptr;
-
- if (x_non_contig) {
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
- } else {
- to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
- }
- if (y_non_contig) {
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
- } else {
- to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
- }
- GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
- GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
-
- if (dryrun) {
- const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
- const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
- if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
- (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
- GGML_ABORT("Requested preallocation size is too large");
- }
- if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
- ctx->prealloc_size_x = x_sz_upd;
- }
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
- ctx->prealloc_size_y = y_sz_upd;
- }
- if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
- ctx->prealloc_size_split_k = split_k_size;
- }
-
- // Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
- if (qx_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
- }
- if (qy_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
- }
- if (split_k > 1) {
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
- }
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
- vk_buffer d_X;
- uint64_t x_buf_offset = 0;
- vk_buffer d_Y;
- uint64_t y_buf_offset = 0;
- if (!src0_uma) {
- d_Qx = src0_buf_ctx->dev_buffer;
- qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- }
- if (!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
- if (qx_needs_dequant) {
- d_X = ctx->prealloc_x;
- GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
- } else {
- d_X = d_Qx;
- x_buf_offset = qx_buf_offset;
- GGML_ASSERT(qx_sz == x_sz);
- }
- if (qy_needs_dequant) {
- d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
- } else {
- d_Y = d_Qy;
- y_buf_offset = qy_buf_offset;
- GGML_ASSERT(qy_sz == y_sz);
- }
-
- if (x_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
- } else if (qx_needs_dequant) {
- const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
- }
- if (y_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
- }
-
- uint32_t stride_batch_x = ne00*ne01;
- uint32_t stride_batch_y = ne10*ne11;
-
- if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
- stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
- }
-
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
- }
-
- // compute
- ggml_vk_matmul(
- ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
- { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
- ne01, ne11, ne10,
- ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
- split_k, ne12*ne13, ne02, ne12, r2, r3
- ); // NOLINT
-}
-
-static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
- GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- const uint64_t ne03 = src0->ne[3];
-
- const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- const uint64_t ne13 = src1->ne[3];
-
- GGML_ASSERT(ne11 == 1);
-
- const uint64_t ne20 = dst->ne[0];
- const uint64_t ne21 = dst->ne[1];
- const uint64_t ne22 = dst->ne[2];
- const uint64_t ne23 = dst->ne[3];
-
- const uint64_t r2 = ne12 / ne02;
- const uint64_t r3 = ne13 / ne03;
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
- vk_buffer d_Qx;
- size_t qx_buf_offset = 0;
- vk_buffer d_Qy;
- size_t qy_buf_offset = 0;
-
- bool src0_uma = false;
- bool src1_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- src0_uma = d_Qx != nullptr;
- src1_uma = d_Qy != nullptr;
- }
-
- const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
- const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
- const bool qx_needs_dequant = x_non_contig;
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
- // Not implemented
- GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
- const uint64_t x_ne = ne01 * ne00;
- const uint64_t y_ne = ne11 * ne10;
- const uint64_t d_ne = ne11 * ne01;
-
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- vk_pipeline to_fp16_vk_0 = nullptr;
- vk_pipeline to_fp16_vk_1 = nullptr;
- if (x_non_contig) {
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
- }
- if (y_non_contig) {
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
- } else {
- to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
- }
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
- GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
- GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
- GGML_ASSERT(dmmv != nullptr);
-
- if (dryrun) {
- const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
- if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
- GGML_ABORT("Requested preallocation size is too large");
- }
- if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
- ctx->prealloc_size_x = x_sz_upd;
- }
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
- ctx->prealloc_size_y = y_sz_upd;
- }
-
- // Request descriptor sets
- if (qx_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
- }
- if (qy_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
- }
- ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- vk_buffer d_X;
- uint64_t x_buf_offset = 0;
- vk_buffer d_Y;
- uint64_t y_buf_offset = 0;
- if(!src0_uma) {
- d_Qx = src0_buf_ctx->dev_buffer;
- qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- }
- if(!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
- if (qx_needs_dequant) {
- d_X = ctx->prealloc_x;
- } else {
- d_X = d_Qx;
- x_buf_offset = qx_buf_offset;
- GGML_ASSERT(qx_sz == x_sz);
- }
- if (qy_needs_dequant) {
- d_Y = ctx->prealloc_y;
- } else {
- d_Y = d_Qy;
- y_buf_offset = qy_buf_offset;
- GGML_ASSERT(qy_sz == y_sz);
- }
-
- if (x_non_contig) {
- GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
- }
- if (y_non_contig) {
- GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
- }
-
- uint32_t stride_batch_x = ne00*ne01;
- uint32_t stride_batch_y = ne10*ne11;
-
- if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
- stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
- }
-
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
- }
-
- const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
-
- uint32_t groups_x = ne01;
- uint32_t groups_z = 1;
-
- if (ne01 > max_groups_x) {
- groups_z = 64;
- groups_x /= groups_z;
- }
-
- // compute
- const vk_mat_vec_push_constants pc = {
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
- (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
- };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
- { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
- sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
-}
-
-static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
- GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
- GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
- GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- // const uint64_t ne03 = src0->ne[3];
-
- const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- // const uint64_t ne13 = src1->ne[3];
-
- GGML_ASSERT(ne11 == 1);
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
- vk_buffer d_Qy;
- size_t qy_buf_offset = 0;
-
- bool src1_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- src1_uma = d_Qy != nullptr;
- }
-
- const uint64_t x_ne = ne00 * ne01 * ne02;
- const uint64_t y_ne = ne10 * ne11 * ne12;
- const uint64_t d_ne = ne01 * ne11 * ne12;
-
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- if (dryrun) {
- // Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
- const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- if (!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
-
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
- const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
-
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
- const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
-
- // compute
- const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
-}
-
-static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(!ggml_is_permuted(src0));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- // const uint64_t ne03 = src0->ne[3];
-
- const uint64_t nb01 = src0->nb[1];
- const uint64_t nb02 = src0->nb[2];
-
- // const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- // const uint64_t ne13 = src1->ne[3];
-
- GGML_ASSERT(ne11 == 1);
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
- vk_buffer d_Qy = nullptr;
- size_t qy_buf_offset = 0;
-
- bool src1_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- src1_uma = d_Qy != nullptr;
- }
-
- const uint64_t d_ne = ne01 * ne11 * ne12;
-
- const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
- const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
-
- const uint64_t qx_sz = ggml_nbytes(src0);
- const uint64_t qy_sz = ggml_nbytes(src1);
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- if (dryrun) {
- // Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
- const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- if (!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
-
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
- const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
-
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
- const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
-
- // compute
- const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
-}
-
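-// Top-level matmul dispatcher: uses the specialized p021 or nc matrix-vector kernels for permuted or
-// non-contiguous f16 sources with a single dst row, the generic dequant matrix-vector path for other
-// single-row cases, and the full matrix-matrix path otherwise.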
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
- if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
- // detect 0213 permutation, and batch size of 1
- src0->nb[0] <= src0->nb[2] &&
- src0->nb[2] <= src0->nb[1] &&
- src0->nb[1] <= src0->nb[3] &&
- src1->nb[0] <= src1->nb[2] &&
- src1->nb[2] <= src1->nb[1] &&
- src1->nb[1] <= src1->nb[3] &&
- src0->ne[3] == 1 &&
- src1->ne[3] == 1) {
- ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
- } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
- !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
- ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
- } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
- ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
- } else {
- ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
- }
-}
-
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
- GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
- GGML_ASSERT(ids->type == GGML_TYPE_I32);
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- const uint64_t ne03 = src0->ne[3];
-
- const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- const uint64_t ne13 = src1->ne[3];
-
- const uint64_t nei0 = ids->ne[0];
- const uint64_t nei1 = ids->ne[1];
- GGML_ASSERT(nei0 * nei1 <= 3072);
-
- const uint32_t nbi1 = ids->nb[1];
- const uint32_t nbi2 = ids->nb[2];
-
- const uint64_t ne20 = dst->ne[0];
- const uint64_t ne21 = dst->ne[1];
- const uint64_t ne22 = dst->ne[2];
- const uint64_t ne23 = dst->ne[3];
-
- const uint64_t n_as = ne02;
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
- ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
- vk_buffer d_Qx;
- size_t qx_buf_offset = 0;
- vk_buffer d_Qy;
- size_t qy_buf_offset = 0;
- vk_buffer d_ids;
- size_t ids_buf_offset = 0;
-
- bool src0_uma = false;
- bool src1_uma = false;
- bool ids_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
- src0_uma = d_Qx != nullptr;
- src1_uma = d_Qy != nullptr;
- ids_uma = d_ids != nullptr;
- }
-
- const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
- const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
-
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
-
- const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
-
- if (qx_needs_dequant) {
- GGML_ABORT("fatal error");
- }
-
- // Not implemented
- GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
- const uint64_t x_ne = ne01 * ne00;
- const uint64_t y_ne = ne11 * ne10;
- const uint64_t d_ne = ne21 * ne20;
-
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
-
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
-
- const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
- const uint64_t ids_sz = nbi2;
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- vk_pipeline to_fp16_vk_0 = nullptr;
- vk_pipeline to_fp16_vk_1 = nullptr;
-
- if (x_non_contig) {
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
- } else {
- to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
- }
- if (y_non_contig) {
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
- } else {
- to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
- }
- GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
- GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
-
- if (dryrun) {
- const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
- if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
- GGML_ABORT("Requested preallocation size is too large");
- }
- if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
- ctx->prealloc_size_x = x_sz_upd;
- }
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
- ctx->prealloc_size_y = y_sz_upd;
- }
-
- // Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
- if (qx_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
- }
- if (qy_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
- }
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- vk_buffer d_X;
- uint64_t x_buf_offset = 0;
- vk_buffer d_Y;
- uint64_t y_buf_offset = 0;
- if (!src0_uma) {
- d_Qx = src0_buf_ctx->dev_buffer;
- qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- }
- if (!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
- if (!ids_uma) {
- d_ids = ids_buf_ctx->dev_buffer;
- ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
- GGML_ASSERT(d_ids != nullptr);
- }
- if (qx_needs_dequant) {
- d_X = ctx->prealloc_x;
- GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
- } else {
- d_X = d_Qx;
- x_buf_offset = qx_buf_offset;
- GGML_ASSERT(qx_sz == x_sz);
- }
- if (qy_needs_dequant) {
- d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
- } else {
- d_Y = d_Qy;
- y_buf_offset = qy_buf_offset;
- GGML_ASSERT(qy_sz == y_sz);
- }
-
- if (x_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
- } else if (qx_needs_dequant) {
- const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
- }
- if (y_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
- }
-
- uint32_t stride_batch_x = ne00*ne01;
- uint32_t stride_batch_y = ne10*ne11;
-
- if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
- stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
- }
-
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
- }
-
- // compute
- ggml_vk_matmul_id(
- ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
- { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
- ne01, ne21, ne10, ne10, ne10, ne01,
- stride_batch_x, stride_batch_y, ne20*ne21,
- n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
- ); // NOLINT
-}
-
-static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
- GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
- GGML_ASSERT(ids->type == GGML_TYPE_I32);
-
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- const uint64_t ne03 = src0->ne[3];
-
- const uint64_t ne10 = src1->ne[0];
- const uint64_t ne11 = src1->ne[1];
- const uint64_t ne12 = src1->ne[2];
- const uint64_t ne13 = src1->ne[3];
-
- const uint64_t nei0 = ids->ne[0];
- const uint64_t nei1 = ids->ne[1];
-
- const uint64_t nbi2 = ids->nb[2];
-
- GGML_ASSERT(nei1 == 1);
-
- const uint64_t ne20 = dst->ne[0];
- const uint64_t ne21 = dst->ne[1];
- const uint64_t ne22 = dst->ne[2];
- const uint64_t ne23 = dst->ne[3];
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
- ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
- vk_buffer d_Qx;
- size_t qx_buf_offset = 0;
- vk_buffer d_Qy;
- size_t qy_buf_offset = 0;
- vk_buffer d_ids;
- size_t ids_buf_offset = 0;
-
- bool src0_uma = false;
- bool src1_uma = false;
- bool ids_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
- ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
- ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
- src0_uma = d_Qx != nullptr;
- src1_uma = d_Qy != nullptr;
- ids_uma = d_ids != nullptr;
- }
-
- const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
- const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
- const bool qx_needs_dequant = x_non_contig;
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
- // Not implemented
- GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
- const uint64_t x_ne = ne01 * ne00;
- const uint64_t y_ne = ne11 * ne10;
- const uint64_t d_ne = ne21 * ne20;
-
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
- const uint64_t ids_sz = nbi2;
- const uint64_t d_sz = sizeof(float) * d_ne;
-
- vk_pipeline to_fp16_vk_0 = nullptr;
- vk_pipeline to_fp16_vk_1 = nullptr;
- if (x_non_contig) {
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
- }
- if (y_non_contig) {
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
- } else {
- to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
- }
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
- GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
- GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
- GGML_ASSERT(dmmv != nullptr);
-
- if (dryrun) {
- const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
- if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
- GGML_ABORT("Requested preallocation size is too large");
- }
- if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
- ctx->prealloc_size_x = x_sz_upd;
- }
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
- ctx->prealloc_size_y = y_sz_upd;
- }
-
- // Request descriptor sets
- if (qx_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
- }
- if (qy_needs_dequant) {
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
- }
- ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
- return;
- }
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
- const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
- GGML_ASSERT(d_D != nullptr);
- vk_buffer d_X;
- uint64_t x_buf_offset = 0;
- vk_buffer d_Y;
- uint64_t y_buf_offset = 0;
- if (!src0_uma) {
- d_Qx = src0_buf_ctx->dev_buffer;
- qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_Qx != nullptr);
- }
- if (!src1_uma) {
- d_Qy = src1_buf_ctx->dev_buffer;
- qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Qy != nullptr);
- }
- if (!ids_uma) {
- d_ids = ids_buf_ctx->dev_buffer;
- ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
- GGML_ASSERT(d_ids != nullptr);
- }
- if (qx_needs_dequant) {
- d_X = ctx->prealloc_x;
- } else {
- d_X = d_Qx;
- x_buf_offset = qx_buf_offset;
- GGML_ASSERT(qx_sz == x_sz);
- }
- if (qy_needs_dequant) {
- d_Y = ctx->prealloc_y;
- } else {
- d_Y = d_Qy;
- y_buf_offset = qy_buf_offset;
- GGML_ASSERT(qy_sz == y_sz);
- }
-
- if (x_non_contig) {
- GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
- }
- if (y_non_contig) {
- GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
- }
-
- uint32_t stride_batch_y = ne10*ne11;
-
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
- }
-
- const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
-
- uint32_t groups_x = ne01;
- uint32_t groups_z = 1;
-
- if (ne01 > max_groups_x) {
- groups_z = 64;
- groups_x /= groups_z;
- }
-
- // compute
- const vk_mat_vec_id_push_constants pc = {
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
- (uint32_t)nei0, (uint32_t)ne11,
- };
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
- { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
- vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
- sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
-}
-
-static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
- if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
- ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
- } else {
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
- }
-}
-
-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
- switch (op) {
- case GGML_OP_GET_ROWS:
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- if (dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_get_rows[src0->type];
- }
- if (dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_get_rows_f32[src0->type];
- }
- return nullptr;
- case GGML_OP_ACC:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_acc_f32;
- }
- return nullptr;
- case GGML_OP_ADD:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_add_f32;
- }
- if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_add_f16_f32_f16;
- }
- return nullptr;
- case GGML_OP_MUL:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_mul_f32;
- }
- return nullptr;
- case GGML_OP_DIV:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_div_f32;
- }
- return nullptr;
- case GGML_OP_CONCAT:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_concat_f32;
- }
- if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_concat_f16;
- }
- if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
- return ctx->device->pipeline_concat_i32;
- }
- return nullptr;
- case GGML_OP_UPSCALE:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_upscale_f32;
- }
- return nullptr;
- case GGML_OP_SCALE:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_scale_f32;
- }
- return nullptr;
- case GGML_OP_SQR:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_sqr_f32;
- }
- return nullptr;
- case GGML_OP_SIN:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_sin_f32;
- }
- return nullptr;
- case GGML_OP_COS:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_cos_f32;
- }
- return nullptr;
- case GGML_OP_CLAMP:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_clamp_f32;
- }
- return nullptr;
- case GGML_OP_PAD:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_pad_f32;
- }
- return nullptr;
- case GGML_OP_REPEAT:
- if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
- return ctx->device->pipeline_repeat_f32;
- }
- return nullptr;
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- case GGML_OP_DUP:
- return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
- case GGML_OP_NORM:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_norm_f32;
- }
- return nullptr;
- case GGML_OP_GROUP_NORM:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_group_norm_f32;
- }
- return nullptr;
- case GGML_OP_RMS_NORM:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_rms_norm_f32;
- }
- return nullptr;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(dst)) {
- case GGML_UNARY_OP_SILU:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_silu_f32;
- }
- break;
- case GGML_UNARY_OP_GELU:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_gelu_f32;
- }
- break;
- case GGML_UNARY_OP_GELU_QUICK:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_gelu_quick_f32;
- }
- break;
- case GGML_UNARY_OP_RELU:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_relu_f32;
- }
- break;
- case GGML_UNARY_OP_TANH:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_tanh_f32;
- }
- break;
- default:
- break;
- }
- return nullptr;
- case GGML_OP_DIAG_MASK_INF:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_diag_mask_inf_f32;
- }
- return nullptr;
- case GGML_OP_SOFT_MAX:
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
-
- if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_soft_max_f32;
- }
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_soft_max_f32_f16;
- }
- return nullptr;
- case GGML_OP_ROPE:
- {
- const int mode = ((const int32_t *) dst->op_params)[2];
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
- if (is_neox) {
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_rope_neox_f32;
- }
- if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_rope_neox_f16;
- }
- } else {
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_rope_norm_f32;
- }
- if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_rope_norm_f16;
- }
- }
- return nullptr;
- }
- case GGML_OP_ARGSORT:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
- return ctx->device->pipeline_argsort_f32;
- }
- return nullptr;
- case GGML_OP_SUM_ROWS:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_sum_rows_f32;
- }
- return nullptr;
- case GGML_OP_IM2COL:
- if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_im2col_f32;
- }
- if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_im2col_f32_f16;
- }
- return nullptr;
- case GGML_OP_TIMESTEP_EMBEDDING:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_timestep_embedding_f32;
- }
- return nullptr;
- case GGML_OP_POOL_2D:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_pool2d_f32;
- }
- return nullptr;
- case GGML_OP_LEAKY_RELU:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_leaky_relu_f32;
- }
- return nullptr;
- default:
- return nullptr;
- }
-
- GGML_UNUSED(src2);
-}
-
-static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
- switch (op) {
- case GGML_OP_CPY:
- case GGML_OP_GET_ROWS:
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SQR:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- case GGML_OP_PAD:
- case GGML_OP_REPEAT:
- return true;
- default:
- return false;
- }
-}
-
-template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- if (src1 != nullptr) {
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- }
- if (src2 != nullptr) {
- std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
- }
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
- std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
- GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
- GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
- GGML_ASSERT(dst->buffer != nullptr);
- const uint64_t ne00 = src0->ne[0];
- const uint64_t ne01 = src0->ne[1];
- const uint64_t ne02 = src0->ne[2];
- const uint64_t ne03 = src0->ne[3];
- const uint64_t ne0 = ne00 * ne01;
-
- const bool use_src1 = src1 != nullptr;
- const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
- const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
- const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
- const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
- const uint64_t ne1 = ne10 * ne11;
- // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;
-
- const bool use_src2 = src2 != nullptr;
- const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
- const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
- const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
- const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
- const uint64_t ne2 = ne20 * ne21;
-
- const uint64_t ned0 = dst->ne[0];
- const uint64_t ned1 = dst->ne[1];
- const uint64_t ned2 = dst->ne[2];
- const uint64_t ned3 = dst->ne[3];
- const uint64_t ned = ned0 * ned1;
-
- vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
-
- if (pipeline == nullptr) {
- std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
- if (src1 != nullptr) {
- std::cerr << " and " << ggml_type_name(src1->type);
- }
- std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
- GGML_ABORT("fatal error");
- }
-
- if (dryrun) {
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
- return;
- }
-
- const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
-
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
- ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
-
- vk_buffer d_X = nullptr;
- size_t x_buf_offset = 0;
- vk_buffer d_Y = nullptr;
- size_t y_buf_offset = 0;
- vk_buffer d_Z = nullptr;
- size_t z_buf_offset = 0;
-
- bool src0_uma = false;
- bool src1_uma = false;
- bool src2_uma = false;
-
- if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
- src0_uma = d_X != nullptr;
- if (use_src1) {
- ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
- src1_uma = d_Y != nullptr;
- }
- if (use_src2) {
- ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
- src2_uma = d_Z != nullptr;
- }
- }
-
- uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
- uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
- uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
- uint64_t d_sz = ggml_type_size(dst->type) * ned;
-
- vk_buffer d_D = dst_buf_ctx->dev_buffer;
-
- // Workaround for tiny tensor inputs on ROPE
- if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
- y_sz = VK_WHOLE_SIZE;
- }
-
- GGML_ASSERT(d_D != nullptr);
- uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
- GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
- if (!src0_uma) {
- d_X = src0_buf_ctx->dev_buffer;
- x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
- GGML_ASSERT(d_X != nullptr);
- }
- if (use_src1 && !src1_uma) {
- d_Y = src1_buf_ctx->dev_buffer;
- y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
- GGML_ASSERT(d_Y != nullptr);
- }
- if (use_src2 && !src2_uma) {
- d_Z = src2_buf_ctx->dev_buffer;
- z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
- GGML_ASSERT(d_Z != nullptr);
- }
-
- if (op_supports_incontiguous) {
- x_sz = ggml_nbytes(src0);
- y_sz = use_src1 ? ggml_nbytes(src1) : 0;
- z_sz = use_src2 ? ggml_nbytes(src2) : 0;
- d_sz = ggml_nbytes(dst);
-
- if (x_buf_offset + x_sz >= d_X->size) {
- x_sz = VK_WHOLE_SIZE;
- }
- if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
- y_sz = VK_WHOLE_SIZE;
- }
- if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
- z_sz = VK_WHOLE_SIZE;
- }
- if (d_buf_offset + d_sz >= d_D->size) {
- d_sz = VK_WHOLE_SIZE;
- }
- }
-
- std::array<uint32_t, 3> elements;
-
- // Single call if dimension 2 is contiguous
- GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
-
- switch (op) {
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_SUM_ROWS:
- {
- const uint32_t nr = ggml_nrows(src0);
- if (nr > 262144) {
- elements = { 512, 512, CEIL_DIV(nr, 262144) };
- } else if (nr > 512) {
- elements = { 512, CEIL_DIV(nr, 512), 1 };
- } else {
- elements = { nr, 1, 1 };
- }
- } break;
- case GGML_OP_GROUP_NORM:
- {
- const uint32_t num_groups = dst->op_params[0];
- elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
- } break;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_ROPE:
- elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
- break;
- case GGML_OP_GET_ROWS:
- elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
- break;
- case GGML_OP_ARGSORT:
- elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
- break;
- case GGML_OP_IM2COL:
- {
- const bool is_2D = dst->op_params[6] == 1;
-
- const uint32_t IC = src1->ne[is_2D ? 2 : 1];
-
- const uint32_t KH = is_2D ? src0->ne[1] : 1;
- const uint32_t KW = src0->ne[0];
-
- const uint32_t OH = is_2D ? dst->ne[2] : 1;
- const uint32_t OW = dst->ne[1];
-
- const uint32_t batch = src1->ne[is_2D ? 3 : 2];
-
- elements = { OW * KW * KH, OH, batch * IC };
- } break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- {
- const uint32_t dim = dst->op_params[0];
- uint32_t half_ceil = (dim + 1) / 2;
- elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
- } break;
- case GGML_OP_POOL_2D:
- {
- const uint32_t N = dst->ne[3];
- const uint32_t OC = dst->ne[2];
- const uint32_t OH = dst->ne[1];
- const uint32_t OW = dst->ne[0];
- elements = { N * OC * OH * OW, 1, 1};
- } break;
- case GGML_OP_ADD:
- case GGML_OP_DIV:
- case GGML_OP_MUL:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- case GGML_OP_PAD:
- case GGML_OP_REPEAT:
- case GGML_OP_CPY:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_UNARY:
- {
- const uint32_t ne = ggml_nelements(dst);
- if (ne > 262144) {
- elements = { 512, 512, CEIL_DIV(ne, 262144) };
- } else if (ne > 512) {
- elements = { 512, CEIL_DIV(ne, 512), 1 };
- } else {
- elements = { ne, 1, 1 };
- }
- } break;
- default:
- elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
- break;
- }
-
- if (!op_supports_incontiguous) {
- if (x_sz != VK_WHOLE_SIZE) {
- x_sz *= ne02 * ne03;
- }
- if (use_src1 && y_sz != VK_WHOLE_SIZE) {
- y_sz *= ne12 * ne13;
- }
- if (use_src2 && z_sz != VK_WHOLE_SIZE) {
- z_sz *= ne22 * ne23;
- }
- if (d_sz != VK_WHOLE_SIZE) {
- d_sz *= ned2 * ned3;
- }
- }
-
- if (op == GGML_OP_SOFT_MAX) {
- // Empty src1 is possible in soft_max, but the shader needs a buffer
- vk_subbuffer subbuf_y;
- if (use_src1) {
- subbuf_y = { d_Y, y_buf_offset, y_sz };
- } else {
- subbuf_y = { d_X, 0, x_sz };
- }
-
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (op == GGML_OP_ROPE) {
- // Empty src2 is possible in rope, but the shader needs a buffer
- vk_subbuffer subbuf_z;
- if (use_src2) {
- subbuf_z = { d_Z, z_buf_offset, z_sz };
- } else {
- subbuf_z = { d_X, 0, x_sz };
- }
-
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (op == GGML_OP_IM2COL) {
- // im2col uses only src1 and dst buffers
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (use_src2) {
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (use_src1) {
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else {
- ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- }
-}
-
-static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f, 0,
- }, dryrun);
-}
-
-static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
- const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
-
- int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
- int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
- // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
- int offset = dst->op_params[3] / 4; // op_params[3] is a byte offset; /4 converts it to a float32 element offset
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
- d_offset,
- 0.0f, 0.0f, offset,
- }, dryrun);
-}
-
-static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f, 0,
- }, dryrun);
-}
-
-static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f, 0,
- }, dryrun);
-}
-
-static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f, 0,
- }, dryrun);
-}
-
-static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- int * op_params = (int *)dst->op_params;
-
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t src1_type_size = ggml_type_size(src1->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
- (uint32_t)ggml_nelements(dst),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f, op_params[0],
- }, dryrun);
-}
-
-static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
-
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
-
- ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
- (uint32_t)ggml_nelements(dst), 0,
- (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
- sf0, sf1, sf2, sf3,
- }, dryrun);
-}
-
-static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- float * op_params = (float *)dst->op_params;
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- op_params[0], 0.0f
- }, dryrun);
-}
-
-static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- float * op_params = (float *)dst->op_params;
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- op_params[0], op_params[1],
- }, dryrun);
-}
-
-static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
- (uint32_t)ggml_nelements(dst),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
- (uint32_t)ggml_nelements(dst),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- 0,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t src0_type_size = ggml_type_size(src0->type);
- const uint32_t dst_type_size = ggml_type_size(dst->type);
- const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
-
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
- (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
- d_offset,
- 0.0f, 0.0f,
- }, dryrun);
-}
-
-static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- float * op_params = (float *)dst->op_params;
-
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
-}
-
-static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const int * int_op_params = (const int *)dst->op_params;
- const float * float_op_params = (const float *)dst->op_params;
-
- const uint32_t num_groups = int_op_params[0];
- const float eps = float_op_params[1];
- const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
-}
-
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
-}
-
-static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
-}
-
-static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- int32_t * op_params = (int32_t *)dst->op_params;
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
-}
-
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- float * op_params = (float *)dst->op_params;
-
- float scale = op_params[0];
- float max_bias = op_params[1];
-
- const uint32_t ncols = (uint32_t)src0->ne[0];
- const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
- const uint32_t nrows_y = (uint32_t)src0->ne[1];
-
- const uint32_t n_head_kv = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
- ncols,
- src1 != nullptr ? nrows_y : (uint32_t)0,
- scale, max_bias,
- m0, m1,
- n_head_log2,
- }, dryrun);
-}
-
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
- const int n_dims = ((int32_t *) dst->op_params)[1];
- // const int mode = ((int32_t *) dst->op_params)[2];
- // const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
- const float freq_base = ((float *) dst->op_params)[5];
- const float freq_scale = ((float *) dst->op_params)[6];
- const float ext_factor = ((float *) dst->op_params)[7];
- const float attn_factor = ((float *) dst->op_params)[8];
- const float beta_fast = ((float *) dst->op_params)[9];
- const float beta_slow = ((float *) dst->op_params)[10];
-
- float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
- (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
- freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
- src2 != nullptr,
- }, dryrun);
-}
-
-static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- int32_t * op_params = (int32_t *)dst->op_params;
-
- uint32_t ncols = src0->ne[0];
-
- uint32_t ncols_pad = 1;
- while (ncols_pad < ncols) {
- ncols_pad *= 2;
- }
-
- GGML_ASSERT(ncols_pad <= 1024);
-
- ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
- ncols,
- ncols_pad,
- op_params[0],
- }, dryrun);
-}
-
-static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
-}
-
-static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- const int32_t s0 = dst->op_params[0];
- const int32_t s1 = dst->op_params[1];
- const int32_t p0 = dst->op_params[2];
- const int32_t p1 = dst->op_params[3];
- const int32_t d0 = dst->op_params[4];
- const int32_t d1 = dst->op_params[5];
-
- const bool is_2D = dst->op_params[6] == 1;
-
- const uint32_t IC = src1->ne[is_2D ? 2 : 1];
- const uint32_t IH = is_2D ? src1->ne[1] : 1;
- const uint32_t IW = src1->ne[0];
-
- const uint32_t KH = is_2D ? src0->ne[1] : 1;
- const uint32_t KW = src0->ne[0];
-
- const uint32_t OH = is_2D ? dst->ne[2] : 1;
- const uint32_t OW = dst->ne[1];
-
- const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
-
- const uint32_t pelements = OW * KW * KH;
-
- ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
- batch_offset, offset_delta,
- IC, IW, IH, OW, OH, KW, KH,
- pelements,
- IC * KH * KW,
- s0, s1, p0, p1, d0, d1,
- }, dryrun);
-}
-
-static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const uint32_t dim = dst->op_params[0];
- const uint32_t max_period = dst->op_params[1];
- const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
-
- ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
- nb1, dim, max_period,
- }, dryrun);
-}
-
-static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
- const int32_t k1 = dst->op_params[1];
- const int32_t k0 = dst->op_params[2];
- const int32_t s1 = dst->op_params[3];
- const int32_t s0 = dst->op_params[4];
- const int32_t p1 = dst->op_params[5];
- const int32_t p0 = dst->op_params[6];
-
- const uint32_t IH = src0->ne[1];
- const uint32_t IW = src0->ne[0];
-
- const uint32_t N = dst->ne[3];
-
- const uint32_t OC = dst->ne[2];
- const uint32_t OH = dst->ne[1];
- const uint32_t OW = dst->ne[0];
-
- const uint32_t parallel_elements = N * OC * OH * OW;
-
- ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
- IW, IH, OW, OH, OC,
- parallel_elements,
- op,
- k0, k1, s0, s1, p0, p1,
- }, dryrun);
-}
-
-static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- const float * op_params = (const float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
-}
-
-#ifdef GGML_VULKAN_RUN_TESTS
-static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
- if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
- return;
- }
- i0 = std::max(i0, 5);
- i1 = std::max(i1, 5);
- i2 = std::max(i2, 0);
- fprintf(stderr, " ");
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- fprintf(stderr, "%7d ", idx1);
- }
- fprintf(stderr, "\n");
- for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
- fprintf(stderr, "%7d: ", idx0);
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
- float val;
- if (type == GGML_TYPE_F32) {
- val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
- } else if (type == GGML_TYPE_F16) {
- val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
- } else {
- GGML_ABORT("fatal error");
- }
- fprintf(stderr, "% 7.2f ", val);
- } else {
- fprintf(stderr, " ");
- }
- }
- fprintf(stderr, "\n");
- }
-}
-
-template <typename X_TYPE, typename Y_TYPE>
-static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
- VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
- const size_t x_ne = m * k * batch;
- const size_t y_ne = k * n * batch;
- const size_t d_ne = m * n * batch;
-
- vk_pipeline p;
- std::string shname;
- if (shader_size == 0) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->a_s;
- shname = "F32_ALIGNED_S";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->a_s;
- shname = "F32_F16_ALIGNED_S";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->a_s;
- shname = "F16_F32_ALIGNED_S";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->a_s;
- shname = "F16_ALIGNED_S";
- } else {
- GGML_ABORT("fatal error");
- }
- } else if (shader_size == 1) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->a_m;
- shname = "F32_ALIGNED_M";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->a_m;
- shname = "F32_F16_ALIGNED_M";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->a_m;
- shname = "F16_F32_ALIGNED_M";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->a_m;
- shname = "F16_ALIGNED_M";
- } else {
- GGML_ABORT("fatal error");
- }
- } else if (shader_size == 2) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->a_l;
- shname = "F32_ALIGNED_L";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->a_l;
- shname = "F32_F16_ALIGNED_L";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->a_l;
- shname = "F16_F32_ALIGNED_L";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->a_l;
- shname = "F16_ALIGNED_L";
- } else {
- GGML_ABORT("fatal error");
- }
- } else {
- GGML_ASSERT(0);
- }
-
- const size_t kpad = ggml_vk_align_size(k, p->align);
-
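- // if k is not a multiple of the pipeline's alignment, fall back to the non-aligned shader variants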
- if (k != kpad) {
- if (shader_size == 0) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->s;
- shname = "F32_S";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->s;
- shname = "F32_F16_S";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->s;
- shname = "F16_F32_S";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->s;
- shname = "F16_S";
- }
- } else if (shader_size == 1) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->m;
- shname = "F32_M";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->m;
- shname = "F32_F16_M";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->m;
- shname = "F16_F32_M";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->m;
- shname = "F16_M";
- }
- } else if (shader_size == 2) {
- if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32->l;
- shname = "F32_L";
- } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f32_f16->l;
- shname = "F32_F16_L";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16_f32->l;
- shname = "F16_F32_L";
- } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
- p = ctx->device->pipeline_matmul_f16->l;
- shname = "F16_L";
- }
- }
- }
-
- ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
- if (split_k > 1) {
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
-
- if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
- // Resize buffer
- if (ctx->prealloc_split_k != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_split_k);
- }
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
- }
- }
-
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
- vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-
- X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
- Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
- float* d = (float *) malloc(sizeof(float) * d_ne);
-
- for (size_t i = 0; i < x_ne; i++) {
- if (std::is_same<float, X_TYPE>()) {
- x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
- } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
- x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
- } else {
- GGML_ABORT("fatal error");
- }
- }
- for (size_t i = 0; i < y_ne; i++) {
- if (std::is_same<float, Y_TYPE>()) {
- // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
- } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
- // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
- y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
- } else {
- GGML_ABORT("fatal error");
- }
- }
-
- ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
- ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
-
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
- for (size_t i = 0; i < num_it; i++) {
- ggml_vk_ctx_begin(ctx->device, subctx);
- ggml_vk_matmul(
- ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
- m, n, k,
- k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
- );
- ggml_vk_ctx_end(subctx);
- }
-
- auto begin = std::chrono::high_resolution_clock::now();
- ggml_vk_submit(subctx, ctx->fence);
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
- ctx->device->device.resetFences({ ctx->fence });
-
- auto end = std::chrono::high_resolution_clock::now();
- double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
-
- // copy dst to host
- ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
-
- float * d_chk = (float *) malloc(sizeof(float) * d_ne);
-
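- // compute the reference result on the CPU with ggml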
- ggml_init_params iparams = {
- /*.mem_size =*/ 1024*1024*1024,
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
-
- ggml_context * ggml_ctx = ggml_init(iparams);
-
- ggml_type src0_type;
- ggml_type src1_type;
-
- if (std::is_same<float, X_TYPE>()) {
- src0_type = GGML_TYPE_F32;
- } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
- src0_type = GGML_TYPE_F16;
- } else {
- GGML_ABORT("fatal error");
- }
- if (std::is_same<float, Y_TYPE>()) {
- src1_type = GGML_TYPE_F32;
- } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
- src1_type = GGML_TYPE_F16;
- } else {
- GGML_ABORT("fatal error");
- }
-
- ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
- ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
- ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
-
- src0_ggml->data = x;
- src1_ggml->data = y;
- tensor_ggml->data = d_chk;
-
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
- ggml_build_forward_expand(cgraph, tensor_ggml);
-
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
-
- ggml_free(ggml_ctx);
-
- double avg_err = 0.0;
- int first_err_n = -1;
- int first_err_m = -1;
- int first_err_b = -1;
-
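- // compare the GPU result against the CPU reference and record the position of the first large error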
- for (size_t i = 0; i < m*n*batch; i++) {
- double err = std::fabs(d[i] - d_chk[i]);
- avg_err += err;
-
- if (err > 0.05f && first_err_n == -1) {
- first_err_b = i / (m * n);
- first_err_n = (i % (m * n)) / m;
- first_err_m = (i % (m * n)) % m;
- }
- }
-
- avg_err /= m * n * batch;
-
- double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
-
- std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
-
- if (avg_err > 0.1) {
- std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
- std::cerr << "Actual result: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
- std::cerr << std::endl;
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
- std::cerr << "Expected result: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- if (split_k > 1) {
- float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
- ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
-
- std::cerr << "d_buf0: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf1: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf2: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf3: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- free(split_k_buf);
- }
- }
-
- free(d_chk);
-
- ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
- ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
-
- ggml_vk_destroy_buffer(d_X);
- ggml_vk_destroy_buffer(d_Y);
- ggml_vk_destroy_buffer(d_D);
-
- ggml_pipeline_cleanup(p);
- ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
-
- free(x);
- free(y);
- free(d);
-}
-
-static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
- if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
- return;
- }
- i0 = std::max(i0, 5);
- i1 = std::max(i1, 5);
- i2 = std::max(i2, 0);
- i3 = std::max(i3, 0);
- fprintf(stderr, " ");
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- fprintf(stderr, "%7d ", idx1);
- }
- fprintf(stderr, "\n");
- for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
- fprintf(stderr, "%7d: ", idx0);
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
- float val;
- if (tensor->type == GGML_TYPE_F32) {
- val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
- } else if (tensor->type == GGML_TYPE_F16) {
- val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
- } else {
- GGML_ABORT("fatal error");
- }
- fprintf(stderr, "% 7.2f ", val);
- } else {
- fprintf(stderr, " ");
- }
- }
- fprintf(stderr, "\n");
- }
-}
-
-static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
- ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
-}
-
-static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
- if (quant == GGML_TYPE_F32) {
- memcpy(to, from, sizeof(float) * ne);
- return;
- }
-
- const auto * tt = ggml_get_type_traits(quant);
-
- ggml_to_float_t dequant_fn = tt->to_float;
-
- dequant_fn(from, to, ne);
-}
-
-static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
- VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
- const size_t x_sz = sizeof(float) * ne;
- const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
- const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
- float * x = (float *) malloc(x_sz);
- void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
- float * x_ref = (float *) malloc(x_sz);
- ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
-
- for (size_t i = 0; i < ne; i++) {
- x[i] = rand() / (float)RAND_MAX;
- }
-
- vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
-
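- // quantize the random data on the CPU, then dequantize it on the CPU as the reference result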
- ggml_vk_quantize_data(x, qx, ne, quant);
- ggml_vk_dequantize_data(qx, x_ref, ne, quant);
-
- ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
-
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
- ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
-
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
- ggml_vk_ctx_begin(ctx->device, subctx);
- const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
- ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
- ggml_vk_ctx_end(subctx);
-
- auto begin = std::chrono::high_resolution_clock::now();
-
- ggml_vk_submit(subctx, ctx->fence);
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
- ctx->device->device.resetFences({ ctx->fence });
-
- auto end = std::chrono::high_resolution_clock::now();
-
- double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
- ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
-
- int first_err = -1;
-
- double avg_err = 0.0;
- for (size_t i = 0; i < ne; i++) {
- double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
- avg_err += error;
-
- if (first_err < 0 && error > 0.05) {
- first_err = i;
- }
- }
-
- avg_err /= ne;
-
- std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
-
- if (avg_err > 0.1) {
- std::cerr << "first_error = " << first_err << std::endl;
- std::cerr << "Actual result: " << std::endl << std::endl;
- for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
- std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
- }
- std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
- for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
- std::cerr << x_ref[i] << ", ";
- }
- std::cerr << std::endl;
- }
-
- ggml_vk_destroy_buffer(x_buf);
- ggml_vk_destroy_buffer(qx_buf);
-
- free(x);
- free(qx);
- free(x_ref);
- free(x_chk);
-}
-
-static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
- VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
- const size_t x_ne = m * k * batch;
- const size_t y_ne = k * n * batch;
- const size_t d_ne = m * n * batch;
-
- vk_pipeline p;
- std::string shname;
- if (shader_size == 0) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
- shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
- } else if (shader_size == 1) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
- shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
- } else if (shader_size == 2) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
- shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
- } else {
- GGML_ASSERT(0);
- }
-
- const size_t kpad = ggml_vk_align_size(k, p->align);
-
- if (k != kpad) {
- if (shader_size == 0) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
- shname = std::string(ggml_type_name(quant)) + "_S";
- } else if (shader_size == 1) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
- shname = std::string(ggml_type_name(quant)) + "_M";
- } else if (shader_size == 2) {
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
- shname = std::string(ggml_type_name(quant)) + "_L";
- } else {
- GGML_ASSERT(0);
- }
- }
-
- const size_t x_sz = sizeof(float) * x_ne;
- const size_t y_sz = sizeof(float) * y_ne;
- const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
- const size_t d_sz = sizeof(float) * d_ne;
- float * x = (float *) malloc(x_sz);
- float * y = (float *) malloc(y_sz);
- void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- float * d = (float *) malloc(d_sz);
- float * d_chk = (float *) malloc(d_sz);
-
- for (size_t i = 0; i < x_ne; i++) {
- x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
- }
-
- ggml_vk_quantize_data(x, qx, x_ne, quant);
-
- for (size_t i = 0; i < y_ne; i++) {
- // y[i] = rand() / (float)RAND_MAX;
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
- }
-
- ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
- if (split_k > 1) {
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
-
- if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
- // Resize buffer
- if (ctx->prealloc_split_k != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_split_k);
- }
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
- }
- }
-
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
- ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
- ggml_vk_buffer_write(y_buf, 0, y, y_sz);
-
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
- for (size_t i = 0; i < num_it; i++) {
- ggml_vk_ctx_begin(ctx->device, subctx);
- ggml_vk_matmul(
- ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
- m, n, k,
- k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
- );
- ggml_vk_ctx_end(subctx);
- }
-
- auto begin = std::chrono::high_resolution_clock::now();
-
- ggml_vk_submit(subctx, ctx->fence);
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant_matmul waitForFences");
- ctx->device->device.resetFences({ ctx->fence });
-
- auto end = std::chrono::high_resolution_clock::now();
-
- double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
- ggml_vk_buffer_read(d_buf, 0, d, d_sz);
-
- ggml_init_params iparams = {
- /*.mem_size =*/ 1024*1024*1024,
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ true,
- };
-
- ggml_context * ggml_ctx = ggml_init(iparams);
-
- ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
- ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
- ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
-
- src0_ggml->data = qx;
- src1_ggml->data = y;
- tensor_ggml->data = d_chk;
-
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
- ggml_build_forward_expand(cgraph, tensor_ggml);
-
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
-
- ggml_free(ggml_ctx);
-
- double avg_err = 0.0;
- int first_err_n = -1;
- int first_err_m = -1;
- int first_err_b = -1;
-
- for (size_t i = 0; i < m*n*batch; i++) {
- double err = std::fabs(d[i] - d_chk[i]);
- avg_err += err;
-
- if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
- first_err_b = i / (m * n);
- first_err_n = (i % (m * n)) / m;
- first_err_m = (i % (m * n)) % m;
- }
- }
-
- avg_err /= m * n * batch;
-
- double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
-
- std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
-
- if (avg_err > 0.01 || std::isnan(avg_err)) {
- std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
- std::cerr << "Actual result: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
- std::cerr << std::endl;
- std::cerr << "Expected result: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- if (split_k > 1) {
- float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
- ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
-
- std::cerr << "d_buf0: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf1: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf2: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- std::cerr << "d_buf3: " << std::endl << std::endl;
- ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
- free(split_k_buf);
- }
- }
-
- ggml_vk_destroy_buffer(qx_buf);
- ggml_vk_destroy_buffer(y_buf);
- ggml_vk_destroy_buffer(d_buf);
-
- free(x);
- free(qx);
- free(y);
- free(d);
- free(d_chk);
-}
-#endif
-
-static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
-#if defined(GGML_VULKAN_RUN_TESTS)
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
-
- ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
-
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
-
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
-
- std::cerr << std::endl;
-
- const std::vector<size_t> vals {
- 8, 8, 8,
- 100, 46, 576,
- 623, 111, 128,
- 100, 46, 558,
- 512, 1, 256,
- 128, 110, 622,
- 511, 511, 127,
- 511, 511, 7,
- 511, 511, 17,
- 49, 49, 128,
- 128, 49, 49,
- 4096, 49, 4096,
- 11008, 49, 4096,
- 4096, 49, 11008,
- 32000, 49, 4096,
- 512, 512, 128,
- 128, 512, 512,
- 4096, 512, 4096,
- 11008, 512, 4096,
- 4096, 512, 11008,
- 32000, 512, 4096,
- };
- const size_t num_it = 1;
- for (size_t i = 0; i < vals.size(); i += 3) {
- ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
- ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
- ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
- std::cerr << std::endl;
- }
-
- GGML_ABORT("fatal error");
-#endif
-
- if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
- VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
- // Resize buffer
- if (ctx->prealloc_x != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_x);
- }
- ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
- }
- if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
- VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
- // Resize buffer
- if (ctx->prealloc_y != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_y);
- }
- ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
- }
- if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
- VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
- // Resize buffer
- if (ctx->prealloc_split_k != nullptr) {
- ggml_vk_destroy_buffer(ctx->prealloc_split_k);
- }
- ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
- }
-}
-
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
-
- // Returns true if node has enqueued work into the queue, false otherwise
- // If submit is true, all operations queued so far are submitted to Vulkan to overlap command list creation on the CPU with execution on the GPU.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
- if (ggml_is_empty(node) || !node->buffer) {
- return false;
- }
-
- VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
- ctx->semaphore_idx = 0;
-
- const ggml_tensor * src0 = node->src[0];
- const ggml_tensor * src1 = node->src[1];
- const ggml_tensor * src2 = node->src[2];
-
- switch (node->op) {
- // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_NONE:
- return false;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(node)) {
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_TANH:
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_REPEAT:
- case GGML_OP_GET_ROWS:
- case GGML_OP_ADD:
- case GGML_OP_ACC:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- case GGML_OP_PAD:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- case GGML_OP_DUP:
- case GGML_OP_NORM:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_ROPE:
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- case GGML_OP_ARGSORT:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_IM2COL:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_POOL_2D:
- case GGML_OP_LEAKY_RELU:
- break;
- default:
- std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
- GGML_ABORT("fatal error");
- return false;
- }
-
- vk_context compute_ctx;
-
- if (!dryrun) {
- if (ctx->compute_ctx.expired()) {
- compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
- ctx->compute_ctx = compute_ctx;
- ggml_vk_ctx_begin(ctx->device, compute_ctx);
- } else {
- compute_ctx = ctx->compute_ctx.lock();
- }
- }
-
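- // record the operation into the compute context; in a dry run this only collects pipeline and buffer requirements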
- switch (node->op) {
- case GGML_OP_REPEAT:
- ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_ACC:
- ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_GET_ROWS:
- ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_ADD:
- ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_MUL:
- ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_DIV:
- ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_CONCAT:
- ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_UPSCALE:
- ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_SCALE:
- ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_SQR:
- ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_SIN:
- ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_COS:
- ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_CLAMP:
- ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_PAD:
- ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- case GGML_OP_DUP:
- ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_NORM:
- ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_GROUP_NORM:
- ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_RMS_NORM:
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(node)) {
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_TANH:
- ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_DIAG_MASK_INF:
- ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_SOFT_MAX:
- ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
-
- break;
- case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_SUM_ROWS:
- ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_IM2COL:
- ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_POOL_2D:
- ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_LEAKY_RELU:
- ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
-
- break;
- case GGML_OP_MUL_MAT:
- ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
-
- break;
- case GGML_OP_MUL_MAT_ID:
- ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
-
- break;
- default:
- return false;
- }
-
- if (dryrun) {
- return false;
- }
-
- ctx->tensor_ctxs[node_idx] = compute_ctx;
-
-#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF)
- // Force context reset on each node so that each tensor ends up in its own context
- // and can be run and compared to its CPU equivalent separately
- last_node = true;
-#endif
-
- if (submit || last_node) {
- ggml_vk_ctx_end(compute_ctx);
-
- // TODO: it would probably be better to pass an exit_node flag to ggml_vk_compute_forward
- if (last_node) {
- compute_ctx->exit_tensor_idx = node_idx_begin;
- }
- else {
- compute_ctx->exit_tensor_idx = -1;
- }
-
- ctx->compute_ctx.reset();
-
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
- if (!ok) {
- if (node->op == GGML_OP_UNARY) {
- std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
- }
- else {
- std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
- }
- }
-
- }
- return true;
-}
-
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
- ggml_backend_buffer * buf = nullptr;
-
- switch (tensor->op) {
- case GGML_OP_ADD:
- case GGML_OP_ACC:
- case GGML_OP_GET_ROWS:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- case GGML_OP_PAD:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- case GGML_OP_DUP:
- case GGML_OP_NORM:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_ROPE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_NONE:
- case GGML_OP_ARGSORT:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_IM2COL:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_POOL_2D:
- case GGML_OP_LEAKY_RELU:
- case GGML_OP_REPEAT:
- buf = tensor->buffer;
-
- break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(tensor)) {
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_TANH:
- buf = tensor->buffer;
- break;
- default:
- return false;
- }
- break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- buf = tensor->buffer;
-
- break;
- default:
- return false;
- }
-
- if (buf == nullptr) {
- return false;
- }
-
- VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
-
- vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
-
- // always wait for the GPU work to be done for the last submit
- if (tensor_idx == subctx->exit_tensor_idx) {
- use_fence = true;
- }
-
- // Only run if ctx hasn't been submitted yet
- if (!subctx->seqs.empty()) {
-#ifdef GGML_VULKAN_CHECK_RESULTS
- ggml_vk_check_results_0(tensor);
- use_fence = true;
-#endif
-
- // Do staging buffer copies
- for (auto& cpy : subctx->in_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
-
- ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
-
- if (use_fence) {
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
-
- ctx->device->device.resetFences({ ctx->fence });
- }
-#ifdef GGML_VULKAN_CHECK_RESULTS
- ggml_vk_check_results_1(tensor);
-#endif
- }
-
- if (tensor_idx == subctx->exit_tensor_idx) {
- // Do staging buffer copies
- for (auto& cpy : subctx->out_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
- subctx->in_memcpys.clear();
- subctx->out_memcpys.clear();
- }
-
- return true;
-}
-
-// Clean up after graph processing is done
-static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
- VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
- for (auto& buffer : ctx->gc.temp_buffers) {
- ggml_vk_pool_free(ctx, buffer);
- }
- ctx->gc.temp_buffers.clear();
-
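- // clean up the pipelines that were used by this graph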
- for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) {
- vk_pipeline_ref plr = ctx->device->pipelines[dsr.first];
-
- if (plr.expired()) {
- continue;
- }
-
- vk_pipeline pl = plr.lock();
- ggml_pipeline_cleanup(pl);
- }
-
- ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
- ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
-
- for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
- ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
- }
- ctx->gc.semaphores.clear();
-
- for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
- ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
- }
- ctx->gc.tl_semaphores.clear();
- ctx->semaphore_idx = 0;
-
- ctx->event_idx = 0;
-
- for (auto& event : ctx->gc.events) {
- ctx->device->device.resetEvent(event);
- }
-
- ctx->tensor_ctxs.clear();
- ctx->gc.contexts.clear();
- ctx->device->pipeline_descriptor_set_requirements.clear();
-}
-
-// Clean up on backend free
-static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
- VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
- ggml_vk_graph_cleanup(ctx);
-
- ggml_vk_destroy_buffer(ctx->prealloc_x);
- ggml_vk_destroy_buffer(ctx->prealloc_y);
- ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-
- for (auto& buffer : ctx->buffer_pool) {
- ggml_vk_destroy_buffer(buffer);
- }
-
- ctx->prealloc_size_x = 0;
- ctx->prealloc_size_y = 0;
- ctx->prealloc_size_split_k = 0;
-
- for (auto& event : ctx->gc.events) {
- ctx->device->device.destroyEvent(event);
- }
- ctx->gc.events.clear();
-
- ctx->device->device.destroyFence(ctx->fence);
-}
-
-static int ggml_vk_get_device_count() {
- ggml_vk_instance_init();
-
- return vk_instance.device_indices.size();
-}
-
-static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
- ggml_vk_instance_init();
-
- std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
- vk::PhysicalDeviceProperties props;
- devices[device].getProperties(&props);
-
- snprintf(description, description_size, "%s", props.deviceName.data());
-}
-
-// backend interface
-
-#define UNUSED GGML_UNUSED
-
-// device backend
-
-static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
- return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
-}
-
-static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
- ggml_vk_destroy_buffer(ctx->dev_buffer);
- delete ctx;
-}
-
-static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
- return vk_ptr_base;
-
- UNUSED(buffer);
-}
-
-static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
- VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
- if (tensor->view_src != nullptr) {
- GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
- }
-}
-
-static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
- vk_buffer buf = buf_ctx->dev_buffer;
-
- ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
- vk_buffer buf = buf_ctx->dev_buffer;
-
- ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
- if (ggml_backend_buffer_is_vk(src->buffer)) {
- ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-
- vk_buffer src_buf = src_buf_ctx->dev_buffer;
- vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
-
- ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
-
- return true;
- }
- return false;
-
- UNUSED(buffer);
-}
-
-static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
- ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
- /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
- /* .get_base = */ ggml_backend_vk_buffer_get_base,
- /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
- /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_vk_buffer_clear,
- /* .reset = */ NULL,
-};
-
-// vk buffer type
-static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
- ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-
- return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
- ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-
- vk_buffer dev_buffer = nullptr;
- try {
- dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
- } catch (const vk::SystemError& e) {
- return nullptr;
- }
-
- ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
-
- return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
-}
-
-static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
- return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
- ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
- return ctx->device->max_memory_allocation_size;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
- return ggml_nbytes(tensor);
-
- UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
- ggml_vk_instance_init();
-
- VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
-
- vk_device dev = ggml_vk_get_device(dev_num);
-
- return &dev->buffer_type;
-}
-
-// host buffer type
-
-static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
- return GGML_VK_NAME "_Host";
-
- UNUSED(buft);
-}
-
-static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
- return GGML_VK_NAME "_Host";
-
- UNUSED(buffer);
-}
-
-static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
- ggml_vk_host_free(vk_instance.devices[0], buffer->context);
-}
-
-static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
-
- size += 32; // Behave like the CPU buffer type
- void * ptr = nullptr;
- try {
- ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
- } catch (vk::SystemError& e) {
- std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
- std::cerr << "ggml_vulkan: " << e.what() << std::endl;
- // fallback to cpu buffer
- return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
- }
-
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
- buffer->buft = buft;
- buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
-
- return buffer;
-
- UNUSED(buft);
-}
-
-static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
-
- UNUSED(buft);
-}
-
-// Should be changed to return device-specific host buffer type
-// but that probably requires changes in llama.cpp
-ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
- static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
- /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
- /* .context = */ nullptr,
- };
-
- // Make sure device 0 is initialized
- ggml_vk_instance_init();
- ggml_vk_get_device(0);
-
- return &ggml_backend_vk_buffer_type_host;
-}
-
-
-// backend
-
-static const char * ggml_backend_vk_name(ggml_backend_t backend) {
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
- return ctx->name.c_str();
-}
-
-static void ggml_backend_vk_free(ggml_backend_t backend) {
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
-
- ggml_vk_cleanup(ctx);
-
- delete ctx;
- delete backend;
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
- return &ctx->device->buffer_type;
-}
-
-static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
-
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
- vk_context transfer_ctx;
-
- if (ctx->transfer_ctx.expired()) {
- // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ctx->transfer_ctx = transfer_ctx;
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
- } else {
- transfer_ctx = ctx->transfer_ctx.lock();
- }
-
- vk_buffer buf = buf_ctx->dev_buffer;
-
- ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
-
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
- vk_context transfer_ctx;
-
- if (ctx->transfer_ctx.expired()) {
- // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ctx->transfer_ctx = transfer_ctx;
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
- } else {
- transfer_ctx = ctx->transfer_ctx.lock();
- }
-
- vk_buffer buf = buf_ctx->dev_buffer;
-
- ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
- VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
- ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-
- vk_context transfer_ctx;
-
- if (ctx->transfer_ctx.expired()) {
- // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
- ctx->transfer_ctx = transfer_ctx;
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
- } else {
- transfer_ctx = ctx->transfer_ctx.lock();
- }
-
- vk_buffer src_buf = src_buf_ctx->dev_buffer;
- vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
-
- ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
- return true;
- }
-
- return false;
-}
-
-static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
- VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- if(ctx->transfer_ctx.expired()) {
- return;
- }
-
- vk_context transfer_ctx = ctx->transfer_ctx.lock();
-
- ggml_vk_ctx_end(transfer_ctx);
-
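- // Run the host-side memcpys queued for uploads, submit the transfer command
- // buffer and wait for it, then run the memcpys queued for downloads.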
- for (auto& cpy : transfer_ctx->in_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
-
- ggml_vk_submit(transfer_ctx, ctx->fence);
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
- ctx->device->device.resetFences({ ctx->fence });
-
- for (auto& cpy : transfer_ctx->out_memcpys) {
- memcpy(cpy.dst, cpy.src, cpy.n);
- }
-
- ctx->transfer_ctx.reset();
-}
-
-static bool ggml_vk_is_empty(ggml_tensor * node) {
- return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
-}
-
-static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
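- // First pass over the graph: record what each node needs so buffers and
- // descriptor sets can be preallocated before any work is submitted.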
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
- }
- ggml_vk_preallocate_buffers(ctx);
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
- int last_node = cgraph->n_nodes - 1;
-
- // If the last op in the cgraph isn't handled by the GPU backend, the command buffer doesn't get closed properly
- while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
- last_node -= 1;
- }
-
- // Reserve tensor context space for all nodes
- ctx->tensor_ctxs.resize(cgraph->n_nodes);
-
- bool first_node_in_batch = true; // true if next node will be first node in a batch
- int submit_node_idx = 0; // index to first node in a batch
-
- // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
- constexpr int submit_count = 100;
- int submitted_nodes = 0;
- for (int i = 0; i < cgraph->n_nodes; i++) {
- if (first_node_in_batch) {
- submit_node_idx = i;
- }
-
- bool submit = (submitted_nodes >= submit_count) || (i == last_node);
-
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
-
- if (enqueued) {
- ++submitted_nodes;
-
-#ifndef GGML_VULKAN_CHECK_RESULTS
- if (first_node_in_batch) {
- first_node_in_batch = false;
- }
-#endif
- }
-
- if (submit) {
- first_node_in_batch = true;
- submitted_nodes = 0;
- }
- }
-
-#ifdef GGML_VULKAN_PERF
- ctx->device->perf_logger->print_timings();
-#endif
-
- ggml_vk_graph_cleanup(ctx);
-
- return GGML_STATUS_SUCCESS;
-
- UNUSED(backend);
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
- /* .get_name = */ ggml_backend_vk_name,
- /* .free = */ ggml_backend_vk_free,
- /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
- /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
- /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
- /* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_vk_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_vk_guid() {
- static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
- return &guid;
-}
-
-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
- VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
-
- ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
- ggml_vk_init(ctx, dev_num);
-
- ggml_backend_t vk_backend = new ggml_backend {
- /* .guid = */ ggml_backend_vk_guid(),
- /* .interface = */ ggml_backend_vk_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
- /* .context = */ ctx,
- };
-
- return vk_backend;
-}
-
-bool ggml_backend_is_vk(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
-}
-
-int ggml_backend_vk_get_device_count() {
- return ggml_vk_get_device_count();
-}
-
-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
- int dev_idx = vk_instance.device_indices[device];
- ggml_vk_get_device_description(dev_idx, description, description_size);
-}
-
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
-
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
-
- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
-
- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
- if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
- *total = heap.size;
- *free = heap.size;
- break;
- }
- }
-}
-
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
- size_t device;
- std::string name;
- std::string description;
-};
-
-static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
- return ctx->name.c_str();
-}
-
-static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
- return ctx->description.c_str();
-}
-
-static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
- ggml_backend_vk_get_device_memory(ctx->device, free, total);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
- return ggml_backend_vk_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
- UNUSED(dev);
- return ggml_backend_vk_host_buffer_type();
-}
-
-static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
- UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_vk_device_get_name(dev);
- props->description = ggml_backend_vk_device_get_description(dev);
- props->type = ggml_backend_vk_device_get_type(dev);
- ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ true,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ false,
- };
-}
-
-static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
- UNUSED(params);
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
- return ggml_backend_vk_init(ctx->device);
-}
-
-static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_TANH:
- return ggml_is_contiguous(op->src[0]);
- default:
- return false;
- }
- break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_NL:
- break;
- default:
- return false;
- }
- struct ggml_tensor * a;
- struct ggml_tensor * b;
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
- if (a->ne[3] != b->ne[3]) {
- return false;
- }
- if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
- !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
- return false;
- }
-
- return true;
- } break;
- case GGML_OP_GET_ROWS:
- {
- switch (op->src[0]->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_IQ4_NL:
- return true;
- default:
- return false;
- }
- } break;
- case GGML_OP_CONT:
- case GGML_OP_CPY:
- case GGML_OP_DUP:
- {
- ggml_type src0_type = op->src[0]->type;
- ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return true;
- }
- return false;
- } break;
- case GGML_OP_REPEAT:
- return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
- case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_NORM:
- case GGML_OP_GROUP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_ADD:
- case GGML_OP_ACC:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_CLAMP:
- case GGML_OP_PAD:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_ARGSORT:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_IM2COL:
- case GGML_OP_TIMESTEP_EMBEDDING:
- case GGML_OP_POOL_2D:
- case GGML_OP_LEAKY_RELU:
- return true;
- default:
- return false;
- }
-
- UNUSED(dev);
-}
-
-static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
- return false;
- }
-
- ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
- ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-
- return buft_ctx->device->idx == ctx->device;
-}
-
-static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
- const int min_batch_size = 32;
-
- return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
- UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
- /* .get_name = */ ggml_backend_vk_device_get_name,
- /* .get_description = */ ggml_backend_vk_device_get_description,
- /* .get_memory = */ ggml_backend_vk_device_get_memory,
- /* .get_type = */ ggml_backend_vk_device_get_type,
- /* .get_props = */ ggml_backend_vk_device_get_props,
- /* .init_backend = */ ggml_backend_vk_device_init,
- /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
- /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_vk_device_supports_op,
- /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
- /* .offload_op = */ ggml_backend_vk_device_offload_op,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
-};
-
-static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
- UNUSED(reg);
- return GGML_VK_NAME;
-}
-
-static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
- UNUSED(reg);
- return ggml_backend_vk_get_device_count();
-}
-
-static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
- static std::vector<ggml_backend_dev_t> devices;
-
- static bool initialized = false;
-
- {
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
- if (!initialized) {
- for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
- ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
- char desc[256];
- ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
- ctx->device = i;
- ctx->name = GGML_VK_NAME + std::to_string(i);
- ctx->description = desc;
- devices.push_back(new ggml_backend_device {
- /* .iface = */ ggml_backend_vk_device_i,
- /* .reg = */ reg,
- /* .context = */ ctx,
- });
- }
- initialized = true;
- }
- }
-
- GGML_ASSERT(device < devices.size());
- return devices[device];
-}
-
-static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
- /* .get_name = */ ggml_backend_vk_reg_get_name,
- /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
- /* .get_device = */ ggml_backend_vk_reg_get_device,
- /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_vk_reg() {
- static ggml_backend_reg reg = {
- /* .iface = */ ggml_backend_vk_reg_i,
- /* .context = */ nullptr,
- };
-
- return &reg;
-}
-
-// Extension availability
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
-#ifdef GGML_VULKAN_VALIDATE
- // Check for the validation features extension, which is needed to enable validation
- for (const auto& properties : instance_extensions) {
- if (strcmp("VK_EXT_validation_features", properties.extensionName) == 0) {
- return true;
- }
- }
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_validation_features not found." << std::endl;
-#endif
- return false;
-
- UNUSED(instance_extensions);
-}
-static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
-#ifdef __APPLE__
- // Check for the portability enumeration extension, needed for MoltenVK support
- for (const auto& properties : instance_extensions) {
- if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
- return true;
- }
- }
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
-#endif
- return false;
-
- UNUSED(instance_extensions);
-}
-
-// checks
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
-static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
- if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
- return;
- }
- for (int j = 0; j < level; j++) {
- std::cerr << " ";
- }
- std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
-
- done.push_back(tensor);
-
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (tensor->src[i] != nullptr) {
- ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
- }
- }
-}
-
-static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
- if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
- return;
- }
- i0 = std::max(i0, 5);
- i1 = std::max(i1, 5);
- i2 = std::max(i2, 0);
- i3 = std::max(i3, 0);
- fprintf(stderr, " ");
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- fprintf(stderr, "%7d ", idx1);
- }
- fprintf(stderr, "\n");
- for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
- fprintf(stderr, "%7d: ", idx0);
- for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
- if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
- float val;
- if (tensor->type == GGML_TYPE_F32) {
- val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
- } else if (tensor->type == GGML_TYPE_F16) {
- val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
- } else if (tensor->type == GGML_TYPE_I32) {
- val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
- } else {
- GGML_ABORT("fatal error");
- }
- fprintf(stderr, "% 7.2f ", val);
- } else {
- fprintf(stderr, " ");
- }
- }
- fprintf(stderr, "\n");
- }
-}
-
-static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
- void * tensor_data = tensor->data;
-
- const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
-
- if (is_gpu) {
- const size_t tensor_size = ggml_nbytes(tensor);
- tensor_data = malloc(tensor_size);
-
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
- vk_buffer buffer_gpu = buf_ctx->dev_buffer;
- ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
- }
-
- std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
- std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
- if (tensor->src[0] != nullptr) {
- std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
- }
- if (tensor->src[1] != nullptr) {
- std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
- }
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(tensor, done);
-
- if (is_gpu) {
- free(tensor_data);
- }
-}
-
-void * comp_result;
-size_t comp_size;
-size_t comp_nb[GGML_MAX_DIMS];
-size_t check_counter = 0;
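-// Recompute the node on the CPU using a cloned ggml context/graph and keep the
-// reference result in comp_result/comp_nb, so ggml_vk_check_results_1 can compare
-// the Vulkan output against it after the GPU has executed the node.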
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
- if (tensor->op == GGML_OP_TRANSPOSE) {
- return;
- }
-
- check_counter++;
- if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
- return;
- }
-
- VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
-
- ggml_tensor * src0 = tensor->src[0];
- ggml_tensor * src1 = tensor->src[1];
- ggml_tensor * src2 = tensor->src[2];
-
- struct ggml_init_params iparams = {
- /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
- /*.mem_buffer =*/ NULL,
- /*.no_alloc =*/ false,
- };
-
- struct ggml_context * ggml_ctx = ggml_init(iparams);
-
- struct ggml_tensor * src0_clone = nullptr;
- struct ggml_tensor * src1_clone = nullptr;
- struct ggml_tensor * src2_clone = nullptr;
- struct ggml_tensor * tensor_clone = nullptr;
-
- size_t src0_size;
- size_t src1_size;
- size_t src2_size;
-
- void * src0_buffer = nullptr;
- void * src1_buffer = nullptr;
- void * src2_buffer = nullptr;
-
- if (src0 != nullptr) {
- src0_clone = ggml_dup_tensor(ggml_ctx, src0);
-
- src0_size = ggml_nbytes(src0);
-
- src0_buffer = malloc(src0_size);
- src0_clone->data = src0_buffer;
- if (ggml_backend_buffer_is_host(src0->buffer)) {
- memcpy(src0_clone->data, src0->data, src0_size);
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
- } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
- uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
- if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
- for (int i3 = 0; i3 < src0->ne[3]; i3++) {
- for (int i2 = 0; i2 < src0->ne[2]; i2++) {
- const int idx = i3*src0->ne[2] + i2;
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
- }
- }
-
- src0_clone->nb[0] = src0->nb[0];
- src0_clone->nb[1] = src0->nb[1];
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
- src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
- }
- } else {
- if (offset + src0_size >= buffer_gpu->size) {
- src0_size = buffer_gpu->size - offset;
- }
- ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
- }
- } else {
- GGML_ABORT("fatal error");
- }
-
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(src0, "src0");
- }
- }
- if (src1 != nullptr) {
- src1_clone = ggml_dup_tensor(ggml_ctx, src1);
-
- src1_size = ggml_nbytes(src1);
-
- src1_buffer = malloc(src1_size);
- src1_clone->data = src1_buffer;
- if (ggml_backend_buffer_is_host(src1->buffer)) {
- memcpy(src1_clone->data, src1->data, src1_size);
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
- } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
- uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
- if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
- for (int i3 = 0; i3 < src1->ne[3]; i3++) {
- for (int i2 = 0; i2 < src1->ne[2]; i2++) {
- const int idx = i3*src1->ne[2] + i2;
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
- }
- }
-
- src1_clone->nb[0] = src1->nb[0];
- src1_clone->nb[1] = src1->nb[1];
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
- src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
- }
- } else {
- if (offset + src1_size >= buffer_gpu->size) {
- src1_size = buffer_gpu->size - offset;
- }
- ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
- }
- } else {
- GGML_ABORT("fatal error");
- }
-
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(src1, "src1");
- }
- }
- if (src2 != nullptr) {
- src2_clone = ggml_dup_tensor(ggml_ctx, src2);
-
- src2_size = ggml_nbytes(src2);
-
- src2_buffer = malloc(src2_size);
- src2_clone->data = src2_buffer;
- if (ggml_backend_buffer_is_host(src2->buffer)) {
- memcpy(src2_clone->data, src2->data, src2_size);
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
- } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
- uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
- if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
- for (int i3 = 0; i3 < src2->ne[3]; i3++) {
- for (int i2 = 0; i2 < src2->ne[2]; i2++) {
- const int idx = i3*src2->ne[2] + i2;
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
- }
- }
-
- src2_clone->nb[0] = src2->nb[0];
- src2_clone->nb[1] = src2->nb[1];
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
- src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
- }
- } else {
- if (offset + src2_size >= buffer_gpu->size) {
- src2_size = buffer_gpu->size - offset;
- }
- ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
- }
- } else {
- GGML_ABORT("fatal error");
- }
-
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(src2, "src2");
- }
- }
-
- if (tensor->op == GGML_OP_MUL_MAT) {
- tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
- } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
- tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
- } else if (tensor->op == GGML_OP_MUL) {
- tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
- } else if (tensor->op == GGML_OP_DIV) {
- tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
- } else if (tensor->op == GGML_OP_CONCAT) {
- tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
- } else if (tensor->op == GGML_OP_UPSCALE) {
- tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
- } else if (tensor->op == GGML_OP_SCALE) {
- tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
- } else if (tensor->op == GGML_OP_SQR) {
- tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
- } else if (tensor->op == GGML_OP_SIN) {
- tensor_clone = ggml_sin(ggml_ctx, src0_clone);
- } else if (tensor->op == GGML_OP_COS) {
- tensor_clone = ggml_cos(ggml_ctx, src0_clone);
- } else if (tensor->op == GGML_OP_CLAMP) {
- tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
- } else if (tensor->op == GGML_OP_PAD) {
- tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
- } else if (tensor->op == GGML_OP_REPEAT) {
- tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
- } else if (tensor->op == GGML_OP_ADD) {
- tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
- } else if (tensor->op == GGML_OP_ACC) {
- tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
- } else if (tensor->op == GGML_OP_NORM) {
- tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
- } else if (tensor->op == GGML_OP_GROUP_NORM) {
- tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
- } else if (tensor->op == GGML_OP_RMS_NORM) {
- tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
- } else if (tensor->op == GGML_OP_SOFT_MAX) {
- if (src1 != nullptr) {
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
- } else {
- tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
- }
- } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
- } else if (tensor->op == GGML_OP_ROPE) {
- const int n_dims = ((int32_t *) tensor->op_params)[1];
- const int mode = ((int32_t *) tensor->op_params)[2];
- //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
- const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
- const float freq_base = ((float *) tensor->op_params)[5];
- const float freq_scale = ((float *) tensor->op_params)[6];
- const float ext_factor = ((float *) tensor->op_params)[7];
- const float attn_factor = ((float *) tensor->op_params)[8];
- const float beta_fast = ((float *) tensor->op_params)[9];
- const float beta_slow = ((float *) tensor->op_params)[10];
- tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
- } else if (tensor->op == GGML_OP_UNARY) {
- switch (ggml_get_unary_op(tensor)) {
- case GGML_UNARY_OP_SILU:
- tensor_clone = ggml_silu(ggml_ctx, src0_clone);
- break;
- case GGML_UNARY_OP_GELU:
- tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
- break;
- case GGML_UNARY_OP_GELU_QUICK:
- tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
- break;
- case GGML_UNARY_OP_RELU:
- tensor_clone = ggml_relu(ggml_ctx, src0_clone);
- break;
- case GGML_UNARY_OP_TANH:
- tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
- break;
- default:
- std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
- GGML_ABORT("fatal error");
- }
- } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
- if (src1 == nullptr) {
- tensor_clone = ggml_dup(ggml_ctx, src0_clone);
- tensor_clone->type = tensor->type;
- } else {
- tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
- }
- } else if (tensor->op == GGML_OP_CONT) {
- tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
- } else if (tensor->op == GGML_OP_RESHAPE) {
- tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
- } else if (tensor->op == GGML_OP_VIEW) {
- tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
- } else if (tensor->op == GGML_OP_PERMUTE) {
- int32_t * params = (int32_t *)tensor->op_params;
- tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
- } else if (tensor->op == GGML_OP_TRANSPOSE) {
- tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
- } else if (tensor->op == GGML_OP_GET_ROWS) {
- tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
- } else if (tensor->op == GGML_OP_ARGSORT) {
- tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
- } else if (tensor->op == GGML_OP_SUM_ROWS) {
- tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
- } else if (tensor->op == GGML_OP_IM2COL) {
- const int32_t s0 = tensor->op_params[0];
- const int32_t s1 = tensor->op_params[1];
- const int32_t p0 = tensor->op_params[2];
- const int32_t p1 = tensor->op_params[3];
- const int32_t d0 = tensor->op_params[4];
- const int32_t d1 = tensor->op_params[5];
-
- const bool is_2D = tensor->op_params[6] == 1;
- tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
- } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
- const int32_t dim = tensor->op_params[0];
- const int32_t max_period = tensor->op_params[1];
- tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
- } else if (tensor->op == GGML_OP_POOL_2D) {
- enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
- const int32_t k0 = tensor->op_params[1];
- const int32_t k1 = tensor->op_params[2];
- const int32_t s0 = tensor->op_params[3];
- const int32_t s1 = tensor->op_params[4];
- const int32_t p0 = tensor->op_params[5];
- const int32_t p1 = tensor->op_params[6];
-
- tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
- } else if (tensor->op == GGML_OP_LEAKY_RELU) {
- const float * op_params = (const float *)tensor->op_params;
- tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
- } else {
- std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
- GGML_ABORT("fatal error");
- }
-
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
- ggml_build_forward_expand(cgraph, tensor_clone);
-
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
-
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- ggml_vk_print_tensor(tensor_clone, "tensor_clone");
- }
-
- comp_size = ggml_nbytes(tensor_clone);
-
- comp_result = malloc(comp_size);
- memcpy(comp_result, tensor_clone->data, comp_size);
- memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
-
- if (src0 != nullptr) {
- free(src0_buffer);
- }
- if (src1 != nullptr) {
- free(src1_buffer);
- }
- if (src2 != nullptr) {
- free(src2_buffer);
- }
-
- ggml_free(ggml_ctx);
-
- VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
-}
-
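-// Read back the Vulkan result for this node and compare it element-wise against
-// the CPU reference stored by ggml_vk_check_results_0, aborting if values diverge
-// or the average absolute error exceeds the threshold.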
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
- if (tensor->op == GGML_OP_TRANSPOSE) {
- return;
- }
- if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
- return;
- }
-
- VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
-
- ggml_tensor * src0 = tensor->src[0];
- ggml_tensor * src1 = tensor->src[1];
- ggml_tensor * src2 = tensor->src[2];
-
- void * tensor_data = tensor->data;
-
- if (ggml_backend_buffer_is_vk(tensor->buffer)) {
- size_t tensor_size = ggml_nbytes(tensor);
- tensor_data = malloc(tensor_size);
-
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
- uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
- if (offset + tensor_size >= buffer_gpu->size) {
- tensor_size = buffer_gpu->size - offset;
- }
-
- ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
- }
-
- float first_error_result = -1.0f;
- float first_error_correct = -1.0f;
- std::array<int, 4> first_error = { -1, -1, -1, -1 };
- double avg_err = 0.0;
- size_t counter = 0;
-
- for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
- float correct = 0.0f;
- float result = 0.0f;
-
- if (buffer_size_fit) {
- if (tensor->type == GGML_TYPE_F32) {
- correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
- result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
- } else if (tensor->type == GGML_TYPE_F16) {
- correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
- result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
- } else if (tensor->type == GGML_TYPE_I32) {
- correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
- result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
- } else {
- std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
- }
- } else {
- std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
- GGML_ABORT("fatal error");
- }
-
- if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
- std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
- std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
- if (src0 != nullptr) {
- std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
- }
- if (src1 != nullptr) {
- std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
- }
- if (src2 != nullptr) {
- std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
- }
- std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
- std::cerr << std::endl << "Correct:" << std::endl;
- ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(tensor, done);
- GGML_ABORT("fatal error");
- }
- if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
- first_error[0] = i0;
- first_error[1] = i1;
- first_error[2] = i2;
- first_error[3] = i3;
- first_error_result = result;
- first_error_correct = correct;
- }
-
- // Special case: skip infinite values to avoid a NaN contribution to avg_err.
- // NaN can also appear in the results; if both values are NaN the error is treated as 0.
- if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
- avg_err += std::fabs(correct - result);
- }
- counter++;
- }
- }
- }
- }
-
- avg_err /= counter;
-
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
- std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
- std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
- if (src0 != nullptr) {
- std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
- }
- if (src1 != nullptr) {
- std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
- }
- if (src2 != nullptr) {
- std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
- }
- std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
- std::cerr << std::endl << "Correct:" << std::endl;
- ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(tensor, done);
- }
-
- if (avg_err > 0.05 || std::isnan(avg_err)) {
- std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
- std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
- if (src0 != nullptr) {
- std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
- }
- if (src1 != nullptr) {
- std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
- }
- if (src2 != nullptr) {
- std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
- }
- std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
- std::cerr << std::endl << "Result:" << std::endl;
- ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
- std::cerr << std::endl << "Correct:" << std::endl;
- ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
- std::cerr << std::endl;
- std::vector<const ggml_tensor *> done;
- ggml_vk_print_graph_origin(tensor, done);
- GGML_ABORT("fatal error");
- } else {
- std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
- }
-
- free(comp_result);
- comp_result = nullptr;
- comp_size = 0;
-
- if (ggml_backend_buffer_is_vk(tensor->buffer)) {
- free(tensor_data);
- }
-
- VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
-}
-#endif
+++ /dev/null
-find_package (Threads REQUIRED)
-
-set(TARGET vulkan-shaders-gen)
-add_executable(${TARGET} vulkan-shaders-gen.cpp)
-install(TARGETS ${TARGET} RUNTIME)
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
-target_link_libraries(${TARGET} PUBLIC Threads::Threads)
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
- if (idx >= p.ne) {
- return;
- }
-
- const uint offset = p.param3;
- const uint src1_i = idx - offset;
- const uint oz = src1_i / p.nb02;
- const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
- const uint ox = src1_i % p.nb01;
-
- if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
- } else {
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
- }
-}
-
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-
-#define BLOCK_SIZE 1024
-#define ASC 0
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) buffer D {int data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ncols;
- uint ncols_pad;
- uint order;
-} p;
-
-shared int dst_row[BLOCK_SIZE];
-
-void swap(uint idx0, uint idx1) {
- int tmp = dst_row[idx0];
- dst_row[idx0] = dst_row[idx1];
- dst_row[idx1] = tmp;
-}
-
-void main() {
- // bitonic sort
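- // Each workgroup sorts one row by index. The row is padded to ncols_pad (a
- // power of two, as the bitonic network requires); padding indices (>= ncols)
- // always compare as greater, so they sink to the end and are never written out.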
- const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
-
- const uint row_offset = row * p.ncols;
-
- // initialize indices
- if (col < p.ncols_pad) {
- dst_row[col] = col;
- }
- barrier();
-
- for (uint k = 2; k <= p.ncols_pad; k *= 2) {
- for (uint j = k / 2; j > 0; j /= 2) {
- const uint ixj = col ^ j;
- if (col < p.ncols_pad && ixj > col) {
- if ((col & k) == 0) {
- if (dst_row[col] >= p.ncols ||
- (dst_row[ixj] < p.ncols && (p.order == ASC ?
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
- ) {
- swap(col, ixj);
- }
- } else {
- if (dst_row[ixj] >= p.ncols ||
- (dst_row[col] < p.ncols && (p.order == ASC ?
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
- ) {
- swap(col, ixj);
- }
- }
- }
- barrier();
- }
- }
-
- if (col < p.ncols) {
- data_d[row_offset + col] = dst_row[col];
- }
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
- const int dim = p.param3;
-
- if (idx >= p.ne) {
- return;
- }
-
- const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
- const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
- const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
- const uint i2_offset = i2*p.ne21*p.ne20;
- const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
- const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
-
- uint o[4] = {0, 0, 0, 0};
- o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
-
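- // o[dim] is src0's extent along the concat dimension: coordinates inside src0's
- // bounds read from src0, everything else reads from src1 shifted back by o.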
- const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
- const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
- const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
-
- const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
-#else
- if (is_src0) {
- data_d[p.d_offset + dst_idx] = data_a[src0_idx];
- } else {
- data_d[p.d_offset + dst_idx] = data_b[src1_idx];
- }
-#endif
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-#extension GL_EXT_control_flow_attributes : require
-
-const uint num_threads = 128;
-
-layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- uint idx = get_idx();
-
- // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
- const uint num_iter = 4;
-
- // fast path for when all four iterations are in-bounds
- if (idx + (num_iter-1)*num_threads < p.ne) {
- [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
-#else
- data_d[p.d_offset + idx] = data_a[idx];
-#endif
- idx += num_threads;
- }
- } else {
- [[unroll]] for (uint i = 0; i < num_iter; ++i) {
- if (idx >= p.ne) {
- continue;
- }
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
-#else
- data_d[p.d_offset + idx] = data_a[idx];
-#endif
- idx += num_threads;
- }
- }
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
-#else
- data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
-#endif
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.x * 16;
-
- if (i >= p.nel) {
- return;
- }
-
- [[unroll]] for (uint l = 0; l < 16; l++) {
- data_b[i + l] = D_TYPE(data_a[i + l]);
- }
-}
+++ /dev/null
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
-
-#if defined(DATA_A_F32)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-#endif
-
-#if defined(DATA_A_F16)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-#endif
-
-#if defined(DATA_A_Q4_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-}
-#endif
-
-#if defined(DATA_A_Q4_1)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(vui & 0xF, vui >> 4) * d + m;
-}
-#endif
-
-#if defined(DATA_A_Q5_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
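- // qh packs the fifth bit of each of the block's 32 quants; bits iqs and iqs+16
- // are moved into bit position 4 to extend the two nibbles of qs[iqs] to 5 bits.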
- const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-}
-#endif
-
-#if defined(DATA_A_Q5_1)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const float m = float(data_a[a_offset + ib].m);
- const uint uint_qh = data_a[a_offset + ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
-}
-#endif
-
-#if defined(DATA_A_IQ4_NL)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
- const float d = float(data_a[a_offset + ib].d);
- const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
-}
-#endif
+++ /dev/null
-#extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint M;
- uint K;
- uint stride_a;
- uint stride_b;
- uint nel;
-} p;
-
-#include "types.comp"
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint q_idx = 8*il;
- const uint b_idx = 1024*i + 32*ir + q_idx;
-
- const float d = float(data_a[ib].d);
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
- data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint ip = tid / 32;
- const uint il = tid - 32 * ip;
- const uint is = 8 * ip + il / 16;
-
- const uint y_idx = i * QUANT_K + 128 * ip + il;
-
- const uint ql_idx = 32 * ip + il;
- const uint8_t qs = data_a[i].qs[32 * ip + il];
-
- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
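- // Each scales[] byte packs a 4-bit scale (low nibble) and a 4-bit minimum (high
- // nibble); the four 2-bit quants of this qs byte are peeled off with shifts of 0/2/4/6.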
- data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
- data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
- data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
- data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint r = gl_LocalInvocationID.x / 4;
- const uint tid = r / 2;
- const uint is0 = r % 2;
- const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
- const uint n = tid / 4;
- const uint j = tid - 4*n;
-
- const uint8_t m = uint8_t(1 << (4*n + j));
- const uint is = 8*n + 2*j + is0;
- const uint shift = 2*j;
-
- const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
- is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
- is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
- (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
- const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
- const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
-
- const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
- const uint qs_idx = 32*n;
-
- for (uint l = l0; l < l0 + 4; ++l) {
- data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
- }
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint q_idx = 8*il;
- const uint b_idx = 1024*i + 32*ir + q_idx;
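- // each 64-thread slice i covers 32 blocks of 32 values, so its output range starts at 1024*i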
-
- const float d = float(data_a[ib].d);
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
- data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f));
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
- data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint il = tid / 8;
- const uint ir = tid % 8;
- const uint is = 2 * il;
- const uint n = 4;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
- const uint y_idx = i * QUANT_K + 64 * il + n * ir;
- const uint qs_idx = 32*il + n * ir;
-
- uint8_t sc;
- uint8_t m;
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
- const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
- const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
-
- [[unroll]] for (uint l = 0; l < n; ++l) {
- data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
- data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
- }
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-
- const uint q_idx = 8*il;
-
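- // bit iqs of qh supplies bit 4 of the low-nibble quant; bit iqs+16 supplies bit 4 of the high-nibble quant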
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- const uint iqs = q_idx + l;
- const uint vui = uint(data_a[ib].qs[iqs]);
- data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
- data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 8*il;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint qh = data_a[ib].qh;
-
- const uint q_idx = 8*il;
-
- [[unroll]] for (uint l = 0; l < 8; ++l) {
- const uint iqs = q_idx + l;
- const uint vui = uint(data_a[ib].qs[iqs]);
- data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
- data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
-
- const uint tid = gl_LocalInvocationID.x;
- const uint il = tid / 16;
- const uint ir = tid % 16;
- const uint is = 2 * il;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
- const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
- const uint qs_idx = 32*il + 2 * ir;
- const uint qh_idx = 2 * ir;
-
- uint8_t sc;
- uint8_t m;
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is] & 63);
- m = uint8_t(data_a[i].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
- }
- const FLOAT_TYPE d1 = dall * sc;
- const FLOAT_TYPE m1 = dmin * m;
-
- if (is < 4) {
- sc = uint8_t(data_a[i].scales[is + 1] & 63);
- m = uint8_t(data_a[i].scales[is + 5] & 63);
- } else {
- sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
- m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
- }
- const FLOAT_TYPE d2 = dall * sc;
- const FLOAT_TYPE m2 = dmin * m;
-
- const uint8_t hm1 = uint8_t(1 << (2 * il ));
- const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
- data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
- data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
- data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
- const uint i = gl_WorkGroupID.x * 256 + wgy;
- if (i >= p.M * p.K / QUANT_K) {
- return;
- }
- const uint tid = gl_LocalInvocationID.x;
- const uint ip = tid / 32;
- const uint il = tid - 32 * ip;
- const uint is = 8 * ip + il / 16;
-
- const uint y_idx = i * QUANT_K + 128 * ip + il;
-
- const uint ql_idx = 64 * ip + il;
- const uint8_t qh = data_a[i].qh[32 * ip + il];
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
-
- data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
- data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
- data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
- data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
- }
-}
+++ /dev/null
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
- const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
- const uint tid = gl_LocalInvocationID.x % 64;
- const uint il = tid/32;
- const uint ir = tid%32;
- const uint ib = 32*i + ir;
- if (ib >= p.nel / 32) {
- return;
- }
-
- const uint b_idx = 1024*i + 32*ir + 16*il;
-
- const float d = float(data_a[ib].d);
-
- const uint q_idx = 16*il;
-
- [[unroll]] for (uint l = 0; l < 16; l += 2) {
- data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
- data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
- }
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_control_flow_attributes : enable
-
-layout (push_constant) uniform parameter
-{
- uint ncols;
- uint rows_per_channel;
- uint n_past;
-} p;
-
-#include "types.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint col = gl_GlobalInvocationID.y;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- const uint i = row*p.ncols + col;
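- // columns past n_past + (row within its channel) are masked with -infinity (0xFF800000 is the IEEE-754 bit pattern of -inf)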
- if (col > p.n_past + row % p.rows_per_channel) {
- data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
- } else {
- data_d[i] = D_TYPE(data_a[i]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const float GELU_COEF_A = 0.044715f;
- const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- const float xi = float(data_a[i]);
- const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
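- // equivalent to 0.5*x*(1 + tanh(val)), using tanh(v) = 1 - 2/(exp(2*v) + 1)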
- data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const float GELU_QUICK_COEF = -1.702f;
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
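- // quick GELU approximation: x * sigmoid(1.702 * x)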
- const float x = float(data_a[i]);
- data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
-}
+++ /dev/null
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint ne;
- uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
- uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
- uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
- uint d_offset;
- float param1; float param2; int param3;
-} p;
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-uint get_idx() {
- return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-}
-
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint src1_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-
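- // taking each index modulo the src1 extent broadcasts src1 along dimensions of size 1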
- return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
-}
-
-uint dst_idx(uint idx) {
- const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
- const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
- const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
- const uint i22_offset = i22*p.ne21*p.ne20;
- const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
- const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
- return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
-}
+++ /dev/null
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint KX;
- uint KY;
- float param1;
- float param2;
-} p;
+++ /dev/null
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_control_flow_attributes : require
-
-layout (push_constant) uniform parameter
-{
- uint ne;
- uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
- uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
- uint d_offset;
- float param1; float param2;
-} p;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-uint get_idx() {
- return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-}
-
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint dst_idx(uint idx) {
- const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
- const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
- const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
- const uint i12_offset = i12*p.ne11*p.ne10;
- const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
- const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
- return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint i00 = gl_GlobalInvocationID.x;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
- if (i00 >= p.ne00) {
- return;
- }
-
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
-#else
- data_d[d_offset + i00] = data_a[a_offset + i00];
-#endif
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-#include "dequant_funcs.comp"
-
-void main() {
- const uint i00 = (gl_GlobalInvocationID.x)*2;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
- if (i00 >= p.ne00) {
- return;
- }
-
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
- const uint ib = a_offset + i00/QUANT_K; // block index
- const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
- const uint iybs = i00 - i00%QUANT_K; // dst block start index
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
- vec2 v = dequantize(ib, iqs, 0);
-
- data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
- data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) buffer D {D_TYPE data_d[];};
-
-shared float tmp[BLOCK_SIZE];
-
-void main() {
- const uint group_size = p.KX;
- const float eps = p.param1;
-
- const uint tid = gl_LocalInvocationID.x;
- const uint start = gl_WorkGroupID.x * group_size + tid;
- const uint end = start + group_size;
-
- tmp[tid] = 0.0f;
-
- // Calculate mean
- [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
- tmp[tid] += float(data_a[col]);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- const float mean = tmp[0] / group_size;
- barrier();
- tmp[tid] = 0.0f;
-
- // Calculate variance
- [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
- const float xi = float(data_a[col]) - mean;
- data_d[col] = D_TYPE(xi);
- tmp[tid] += xi * xi;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- const float variance = tmp[0] / group_size;
- const float scale = inversesqrt(variance + eps);
-
- [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
- data_d[col] *= D_TYPE(scale);
- }
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint batch_offset; uint offset_delta;
- uint IC;
- uint IW; uint IH;
- uint OW; uint OH;
- uint KW; uint KH;
- uint pelements;
- uint CHW;
- int s0; int s1;
- int p0; int p1;
- int d0; int d1;
-} p;
-
-#include "types.comp"
-
-#define BLOCK_SIZE 256
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.x;
- if (i >= p.pelements) {
- return;
- }
-
- const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
- const uint kx = i / ksize;
- const uint kd = kx * ksize;
- const uint ky = (i - kd) / p.OW;
- const uint ix = i % p.OW;
-
- const uint oh = gl_GlobalInvocationID.y;
- const uint batch = gl_GlobalInvocationID.z / p.IC;
- const uint ic = gl_GlobalInvocationID.z % p.IC;
-
- const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
- const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
-
- const uint offset_dst =
- ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
- (ic * (p.KW * p.KH) + ky * p.KW + kx);
-
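- // iiw and iih are unsigned, so conceptually negative offsets wrap to large values and fail the >= bounds checks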
- if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
- data_d[offset_dst] = D_TYPE(0.0f);
- } else {
- const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
- data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
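- // LeakyReLU: keep the positive part, scale the negative part by the slope passed in param1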
- const float val = float(data_a[i]);
- data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ne;
- uint k_num;
-} p;
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
-
- if (idx >= p.ne) {
- return;
- }
-
- float result = 0.0f;
-
- [[unroll]] for (uint i = 0; i < p.k_num; i++) {
- result += data_a[i * p.ne + idx];
- }
-
- data_d[idx] = result;
-}
+++ /dev/null
-#version 450
-
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
- const uint tid = gl_LocalInvocationID.x;
-
- // There are not enough cols to use all threads
- if (tid >= p.ncols) {
- return;
- }
-
- const uint block_size = min(p.ncols, BLOCK_SIZE);
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
- tmp[tid] = FLOAT_TYPE(0.0f);
-
- [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
- const uint col = i*block_size + 2*tid;
- const uint ib = (row*p.ncols + col)/QUANT_K; // block index
- const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
- const uint iybs = col - col%QUANT_K; // y block start index
-
- vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
-
- // matrix multiplication
- tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_8bit_storage : require
-
-#define K_QUANTS_PER_ITERATION 2
-
-#ifdef MUL_MAT_ID
-#define EXPERT_COUNT 8
-#endif
-
-#include "types.comp"
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-#include "dequant_funcs.comp"
-
-layout (push_constant) uniform parameter
-{
- uint ncols;
- uint stride_a;
- uint stride_b;
- uint stride_d;
-
- uint batch_stride_a;
- uint batch_stride_b;
- uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
- uint nei0;
- uint ne11;
-#else
- uint ne02;
- uint ne12;
- uint broadcast2;
- uint broadcast3;
-#endif
-} p;
-
-void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
-#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.y;
-#else
- const uint batch_idx = gl_GlobalInvocationID.y;
-#endif
-
-#ifndef MUL_MAT_ID
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
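- // undo the broadcast factors to recover which batch of A this batch of B/dst maps to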
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-#else
- const uint expert_id = data_ids[expert_idx];
-#endif
-
- a_offset =
-#ifdef MUL_MAT_ID
- expert_id * p.batch_stride_a;
-#else
- batch_idx_a * p.batch_stride_a;
-#endif
- b_offset =
-#ifdef MUL_MAT_ID
- (expert_idx % p.ne11) * p.stride_b;
-#else
- batch_idx * p.batch_stride_b;
-#endif
- d_offset =
-#ifdef MUL_MAT_ID
- expert_idx * p.stride_d;
-#else
- batch_idx * p.batch_stride_d;
-#endif
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
- uint ncols_x;
- uint nrows_x;
- uint row_stride_x;
- uint channel_stride_x;
- uint channel_x_divisor;
- uint b_offset;
- uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
- const uint channel_x = channel / p.channel_x_divisor;
-
- const uint nrows_y = p.ncols_x;
- const uint nrows_dst = p.nrows_x;
- const uint row_dst = row_x;
-
- const uint idst = channel*nrows_dst + row_dst;
-
- tmp[tid] = 0.0f;
-
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
-
- if (col_x >= p.ncols_x) {
- break;
- }
-
- const uint row_y = col_x;
-
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
- const uint iy = channel*nrows_y + row_y;
-
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
- tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- if (tid == 0) {
- dst[idst] = tmp[0];
- }
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
- uint ncols_x;
- uint nrows_x;
- uint nchannels_x;
- uint nchannels_y;
- uint b_offset;
- uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
- const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
-
- const uint nrows_y = p.ncols_x;
- const uint nrows_dst = p.nrows_x;
- const uint row_dst = row_x;
-
- tmp[tid] = FLOAT_TYPE(0.0f);
-
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
-
- if (col_x >= p.ncols_x) {
- break;
- }
-
- // x is transposed and permuted
- const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
- const uint row_y = col_x;
-
- // y is not transposed but permuted
- const uint iy = channel*nrows_y + row_y;
-
- tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
- }
-
- // dst is not transposed and not permuted
- const uint idst = channel*nrows_dst + row_dst;
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
-
- if (tid == 0) {
- dst[idst] = tmp[0];
- }
-}
+++ /dev/null
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0, or 0...1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint s_offset = 8*v_im;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
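- // sum1 accumulates activation * scale * quant terms, sum2 accumulates activation * min terms; they combine below as dall*sum1 - dmin*sum2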
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1))))))));
- sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2))))))));
- }
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx]));
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0, or 0...1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
- const uint8_t m = uint8_t(1 << (4 * v_im));
-
- const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- const uint s_shift = 4 * v_im;
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
- sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
- }
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]);
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0, or 0...1
-
- const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
-
- const uint il = tid/step; // 0...3
- const uint ir = tid - step*il; // 0...7 or 0...3
- const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
-
- const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const uint v_in = il % 2;
-
- const uint l0 = n * (2 * ir + v_in); // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 64*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-#if K_QUANTS_PER_ITERATION == 2
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
-
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3)));
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11)));
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
- const FLOAT_TYPE smin =
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
-#else
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
- const FLOAT_TYPE smin =
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
- + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
-
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(dall, fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
- fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))), fma(-dmin, smin, tmp[tmp_idx]));
-#endif
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/2; // 0...15
- const uint ix = gl_LocalInvocationID.x%2; // 0 or 1
-
- const uint il = tid/4; // 0...3
- const uint ir = tid - 4*il; // 0...3
-
- const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const uint v_in = il % 2;
-
- const uint l0 = 4*ir + 2*v_in; // 0...15
- const uint q_offset = 32*v_im + l0;
- const uint y_offset = 64*v_im + l0;
-
- const uint8_t hm1 = uint8_t(1 << (2*v_im));
- const uint8_t hm2 = uint8_t(hm1 << 4);
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
- const uint y1_idx = i * QUANT_K + y_offset;
- const uint y2_idx = y1_idx + 128;
-
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
-
- const FLOAT_TYPE sx =
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
- FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
- const FLOAT_TYPE sy =
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
- FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
- const FLOAT_TYPE sz =
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
- FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
- const FLOAT_TYPE sw =
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)),
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
- FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
- const FLOAT_TYPE smin =
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
- (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
- uint a_offset, b_offset, d_offset;
- get_offsets(a_offset, b_offset, d_offset);
-
- const uint num_blocks_per_row = p.ncols / QUANT_K;
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
- const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
- const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0, or 0...1
-
- const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
-
- const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const uint v_in = tid - step*v_im; // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
- const uint l0 = v_in; // 0...15
- const uint is = 0;
-#else
- const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
- const uint is = v_in / 4;
-#endif
-
- const uint ql_offset = 64*v_im + l0;
- const uint qh_offset = 32*v_im + l0;
- const uint s_offset = 8*v_im + is;
- const uint y_offset = 128*v_im + l0;
-
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint y_idx = i * QUANT_K + y_offset;
-
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-#if K_QUANTS_PER_ITERATION == 1
- const uint tmp_idx = 16 * ix + tid;
- tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
-#else
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 4; ++l) {
- sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
- }
- tmp[16 * ix + tid] += sum;
-#endif
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier();
- }
- if (tid == 0) {
- data_d[d_offset + row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-
-#ifdef MUL_MAT_ID
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#endif
-
-#include "types.comp"
-
-#ifndef LOAD_VEC_A
-#define LOAD_VEC_A 1
-#endif
-#ifndef LOAD_VEC_B
-#define LOAD_VEC_B 1
-#endif
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
- uint M;
- uint N;
- uint K;
- uint stride_a;
- uint stride_b;
- uint stride_d;
-
- uint batch_stride_a;
- uint batch_stride_b;
- uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
- uint nei0;
- uint nei1;
- uint nbi1;
- uint ne11;
-#else
- uint k_split;
- uint ne02;
- uint ne12;
- uint broadcast2;
- uint broadcast3;
-#endif
-} p;
-
-layout (constant_id = 1) const uint BM = 64;
-layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
-layout (constant_id = 4) const uint WM = 32;
-layout (constant_id = 5) const uint WN = 32;
-layout (constant_id = 6) const uint WMITER = 2;
-layout (constant_id = 7) const uint TM = 4;
-layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
-
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
-
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[3072];
-#endif
-
-void main() {
-#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
-#else
- const uint batch_idx = gl_GlobalInvocationID.z;
-
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
-
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
-
- const uint batch_idx_a = i03 * p.ne02 + i02;
-#endif
-
- const uint blocks_m = (p.M + BM - 1) / BM;
- const uint ir = gl_WorkGroupID.x % blocks_m;
- const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
-
- const uint warp_i = gl_LocalInvocationID.x / WARP;
- const uint warp_r = warp_i % (BM / WM);
- const uint warp_c = warp_i / (BM / WM);
-
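- // Derive the per-warp subtile counts and sizes from the warp tile (WM x WN) and the per-thread tile (TM x TN)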
- const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
- const uint WSUBM = WM / WMITER;
- const uint WSUBN = WN / WNITER;
-
- const uint tiw = gl_LocalInvocationID.x % WARP;
- const uint tiwr = tiw % (WSUBM / TM);
- const uint tiwc = tiw / (WSUBM / TM);
-
- const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
- const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
- const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
- const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
-
- const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
- const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
-
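- // MUL_MAT_ID: collect the (i0, i1) index pairs routed to this expert into shared row_ids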
-#ifdef MUL_MAT_ID
- uint _ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
- if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
- _ne1++;
- }
- }
- }
-
- barrier();
-
- // Workgroup has no work
- if (ic * BN >= _ne1) return;
-#endif
-
-#ifdef MUL_MAT_ID
- const uint start_k = 0;
- const uint end_k = p.K;
-#else
- const uint start_k = ik * p.k_split;
- const uint end_k = min(p.K, (ik + 1) * p.k_split);
-#endif
-
- uint pos_a = (
-#ifdef MUL_MAT_ID
- expert_idx * p.batch_stride_a +
-#else
- batch_idx_a * p.batch_stride_a +
-#endif
- ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
-#ifdef MUL_MAT_ID
- uint pos_b = 0;
-#else
- uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
-#endif
-
- float sums[WMITER * TM * WNITER * TN];
- FLOAT_TYPE cache_a[WMITER * TM];
- FLOAT_TYPE cache_b[WNITER * TN];
-
- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
- sums[i] = 0.0f;
- }
-
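- // Loop over K in steps of BK: stage tiles of A and B into shared memory, then accumulate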
- [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
- [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
-
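- // Load one (pair of) A element(s) into buf_a, dequantizing on the fly for quantized source types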
-#if defined(DATA_A_F32) || defined(DATA_A_F16)
-#if LOAD_VEC_A == 8
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
- buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
- buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
- buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
- buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
-#elif LOAD_VEC_A == 4
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
-#else
- if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
- } else {
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
- }
-#endif
-#elif defined(DATA_A_Q4_0)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q4_1)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q5_0)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q5_1)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const float m = float(data_a[ib].m);
- const uint uint_qh = data_a[ib].qh;
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q8_0)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 16;
- const uint iqs = (idx & 0xF) * 2;
-
- const float d = float(data_a[ib].d);
- const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q2_K)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..62
- const uint scalesi = iqs / 8; // 0..15
- const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-
- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
- const uint scales = data_a[ib].scales[scalesi];
- const vec2 d = vec2(data_a[ib].d);
-
- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q3_K)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 64; // 0,1
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
- const uint hmi = (iqs % 16) * 2; // 0,2,4..30
- const uint j = (iqs % 64) / 4; // 0..15
- const uint is = iqs / 8; // 0..15
- const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
- const uint qsshift = halfsplit * 2; // 0,2,4,6
- const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
-
- const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
- is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
- is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
- (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
- const float dl = float(data_a[ib].d) * float(us - 32);
-
- buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
- buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
-#elif defined(DATA_A_Q4_K)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 32; // 0,1,2,3
- const uint b = (iqs % 32) / 16; // 0,1
- const uint is = 2 * n + b; // 0..7
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
-
- const vec2 loadd = vec2(data_a[ib].d);
-
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
- const float d = loadd.x * sc;
- const float m = -loadd.y * mbyte;
-
- buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
- buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
-#elif defined(DATA_A_Q5_K)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 32; // 0,1,2,3
- const uint b = (iqs % 32) / 16; // 0,1
- const uint is = 2 * n + b; // 0..7
- const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const uint qhi = (iqs % 16) * 2; // 0,2,4..30
-
- const uint8_t hm = uint8_t(1 << (iqs / 16));
-
- const vec2 loadd = vec2(data_a[ib].d);
-
- uint8_t sc;
- uint8_t mbyte;
- if (is < 4) {
- sc = uint8_t(data_a[ib].scales[is ] & 63);
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
- } else {
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
- }
- const float d = loadd.x * sc;
- const float m = -loadd.y * mbyte;
-
- buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
- buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
-#elif defined(DATA_A_Q6_K)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
- const uint ib = idx / 128; // 2 values per idx
- const uint iqs = idx % 128; // 0..127
-
- const uint n = iqs / 64; // 0,1
- const uint b = (iqs % 64) / 32; // 0,1
- const uint is_b = (iqs % 16) / 8; // 0,1
- const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uint is = 8 * n + qhshift + is_b; // 0..15
- const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
- const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
-
- const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
-
- buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
- buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
-#elif defined(DATA_A_IQ4_NL)
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
- const uint ib = idx / 16;
- const uint iqs = idx & 0xF;
-
- const float d = float(data_a[ib].d);
- const uint vui = uint(data_a[ib].qs[iqs]);
- const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
-
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#endif
- }
- [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
-#if LOAD_VEC_B == 8
-#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
- const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-#else
- const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
- buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
- buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
- buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
- buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
-#elif LOAD_VEC_B == 4
-#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
- const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-#else
- const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
-#elif !MUL_MAT_ID
- if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
- } else {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
- }
-#else
- const uint row_i = ic * BN + loadc_b + l;
- if (row_i < _ne1) {
- const u16vec2 row_idx = row_ids[row_i];
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
- } else {
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
- }
-#endif
- }
-
- barrier();
-
- pos_a += BK / LOAD_VEC_A;
- pos_b += BK / LOAD_VEC_B;
-
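- // Multiply the staged tiles: each thread accumulates a TM x TN block of partial sums per (wsir, wsic) subtile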
- for (uint i = 0; i < BK; i++) {
- // Load from shared into cache
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint j = 0; j < TM; j++) {
- cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
- }
- }
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint j = 0; j < TN; j++) {
- cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
- }
- }
-
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
- sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
- }
- }
- }
- }
- }
-
- barrier();
- }
-
- const uint dr = ir * BM + warp_r * WM;
- const uint dc = ic * BN + warp_c * WN;
-
-#ifndef MUL_MAT_ID
- const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-#endif
-
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-
- const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
- const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-#ifdef MUL_MAT_ID
- const uint row_i = dc_warp + cc;
- if (row_i >= _ne1) break;
-
- const u16vec2 row_idx = row_ids[row_i];
-#endif
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-#ifdef MUL_MAT_ID
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
-#else
- if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
- }
-#endif
- }
- }
- }
- }
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared vec2 sum[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
-
- sum[tid] = vec2(0.0f, 0.0f);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const float xi = float(data_a[row*p.KX + col]);
- sum[tid].x += xi;
- sum[tid].y += xi * xi;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- sum[tid] += sum[tid + s];
- }
- barrier();
- }
-
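- // Derive mean and variance from the reduced sums, then normalize the row (p.param1 is used as the epsilon)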
- const float mean = sum[0].x / p.KX;
- const float var = sum[0].y / p.KX - mean * mean;
- const float inv_std = inversesqrt(var + p.param1);
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
- }
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (idx >= p.ne) {
- return;
- }
-
- const uint i3 = idx / (p.ne12*p.ne11*p.ne10);
- const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
- const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);
- const uint i2_offset = i2*p.ne11*p.ne10;
- const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
- const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
-
- const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
- const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
-
- const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
-
- data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(push_constant) uniform parameter {
- uint IW; uint IH;
- uint OW; uint OH;
- uint OC;
- uint pelements;
- uint op;
- int k0; int k1;
- int s0; int s1;
- int p0; int p1;
-} p;
-
-#define BLOCK_SIZE 512
-#define FLT_MAX 3.402823466e+38F
-#define OP_POOL_MAX 0u
-#define OP_POOL_AVG 1u
-
-layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
- if (idx >= p.pelements) {
- return;
- }
-
- const uint O_HW = p.OW * p.OH;
-
- const uint nc = idx / O_HW;
- const uint cur_oh = (idx % O_HW) / p.OW;
- const uint cur_ow = (idx % O_HW) % p.OW;
-
- const int start_h = int(cur_oh) * p.s0 - p.p0;
- const uint bh = max(start_h, 0);
- const uint eh = min(start_h + p.k0, p.IH);
-
- const int start_w = int(cur_ow) * p.s1 - p.p1;
- const uint bw = max(start_w, 0);
- const uint ew = min(start_w + p.k1, p.IW);
-
- const float scale = 1.0 / float(p.k0 * p.k1);
- float res;
-
- if (p.op == OP_POOL_AVG) {
- res = 0.0;
- } else if (p.op == OP_POOL_MAX) {
- res = -FLT_MAX;
- } else {
- return;
- }
-
- #pragma unroll
- for (uint i = bh; i < eh; i++) {
- #pragma unroll
- for (uint j = bw; j < ew; j++) {
- const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);
-
- if (p.op == OP_POOL_AVG) {
- res += cur * scale;
- } else if (p.op == OP_POOL_MAX) {
- res = max(res, cur);
- }
- }
- }
-
- data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- data_d[i] = max(float(data_a[i]), 0);
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-uint src0_idx_mod(uint idx) {
- const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
- const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
- const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
- const uint i12_offset = i12*p.ne11*p.ne10;
- const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
- const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
- return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
-}
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE sum[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint tid = gl_LocalInvocationID.x;
-
- sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
- sum[tid] += xi * xi;
- }
-
- // sum up partial sums and write back result
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- sum[tid] += sum[tid + s];
- }
- barrier();
- }
-
- const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
- const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
-
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
- }
-}
+++ /dev/null
-#include "types.comp"
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_pos[];};
-layout (binding = 2) readonly buffer Z {float data_ff[];};
-layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
- uint ncols;
- uint n_dims;
- float freq_scale;
- uint p_delta_rows;
- float freq_base;
- float ext_factor;
- float attn_factor;
- float corr_dims[2];
- float theta_scale;
- uint has_ff;
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
- float mscale = p.attn_factor;
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = p.freq_scale * theta_extrap;
- float theta = theta_interp;
- if (p.ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
- }
- cos_theta = cos(theta) * mscale;
- sin_theta = sin(theta) * mscale;
-}
+++ /dev/null
-#version 450
-
-#include "rope_head.comp"
-
-void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- if (col >= p.n_dims) {
- const uint i = row*p.ncols + col;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
- const uint i = row*p.ncols + col/2;
- const uint i2 = row/p.p_delta_rows;
-
- const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
-
- const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
-
- float cos_theta, sin_theta;
- rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
-
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + p.n_dims/2]);
-
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
+++ /dev/null
-#version 450
-
-#include "rope_head.comp"
-
-void main() {
- const uint col = gl_GlobalInvocationID.y * 2;
- const uint row = gl_GlobalInvocationID.x;
-
- if (col >= p.ncols) {
- return;
- }
-
- if (col >= p.n_dims) {
- const uint i = row*p.ncols + col;
-
- data_d[i + 0] = data_a[i + 0];
- data_d[i + 1] = data_a[i + 1];
-
- return;
- }
-
- const uint i = row*p.ncols + col;
- const uint i2 = row/p.p_delta_rows;
-
- const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
-
- const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
-
- float cos_theta, sin_theta;
- rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
-
- const float x0 = float(data_a[i + 0]);
- const float x1 = float(data_a[i + 1]);
-
- data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
- data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-const uint num_threads = 128;
-
-layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- uint idx = get_idx();
-
- // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
- const uint num_iter = 4;
-
- [[unroll]] for (uint i = 0; i < num_iter; ++i) {
- if (idx >= p.ne) {
- continue;
- }
-
- data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
- idx += num_threads;
- }
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- const float xi = float(data_a[i]);
- data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint KX;
- uint KY;
- float scale;
- float max_bias;
- float m0;
- float m1;
- uint n_head_log2;
-} p;
-
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE vals[BLOCK_SIZE];
-
-void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint rowy = rowx % p.KY;
-
- float slope = 1.0f;
-
- // ALiBi
- if (p.max_bias > 0.0f) {
- const uint h = rowx/p.KY; // head index
-
- const float base = h < p.n_head_log2 ? p.m0 : p.m1;
- const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- // Find max
- FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
-
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
- const uint col = col0 + tid;
-
- if (col >= p.KX) {
- break;
- }
-
- max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
- }
- vals[tid] = max_val;
-
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- vals[tid] = max(vals[tid], vals[tid + s]);
- }
- barrier();
- }
-
- max_val = vals[0];
- barrier();
-
- // Sum up values
- vals[tid] = FLOAT_TYPE(0.0f);
-
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
- const uint col = col0 + tid;
-
- if (col >= p.KX) {
- break;
- }
-
- const uint i = rowx * p.KX + col;
- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
- vals[tid] += val;
- data_d[i] = D_TYPE(val);
- }
-
- barrier();
- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
- if (tid < s) {
- vals[tid] += vals[tid + s];
- }
- barrier();
- }
-
- const D_TYPE divisor = D_TYPE(vals[0]);
-
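- // Normalize: divide each stored exponential by the row sum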
- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
- const uint col = col0 + tid;
-
- if (col >= p.KX) {
- break;
- }
-
- data_d[rowx*p.KX + col] /= divisor;
- }
-}
+++ /dev/null
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint idx = get_idx();
-
- if (idx >= p.ne) {
- return;
- }
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
- const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
- const uint col = gl_LocalInvocationID.x;
-
- tmp[col] = FLOAT_TYPE(0.0f);
-
- for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
- tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
- }
-
- barrier();
- [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
- if (col < s) {
- tmp[col] += tmp[col + s];
- }
- barrier();
- }
-
- if (col == 0) {
- data_d[row] = D_TYPE(tmp[0]);
- }
-}
+++ /dev/null
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (i >= p.KX) {
- return;
- }
-
- data_d[i] = D_TYPE(tanh(data_a[i]));
-}
+++ /dev/null
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
- uint nb1;
- uint dim;
- uint max_period;
-} p;
-
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 256
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint i = gl_WorkGroupID.y;
- const uint j = gl_GlobalInvocationID.x;
- const uint d_offset = i * p.nb1;
-
- if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
- data_d[d_offset + p.dim] = 0.f;
- }
-
- const uint half_dim = p.dim / 2;
- if (j >= half_dim) {
- return;
- }
-
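- // Sinusoidal timestep embedding: frequencies fall off exponentially from 1 toward 1/max_period over half_dim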
- const float timestep = float(data_a[i]);
- const float freq = float(exp(-log(p.max_period) * j / half_dim));
- const float arg = timestep * freq;
- data_d[d_offset + j] = D_TYPE(cos(arg));
- data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));
-}
+++ /dev/null
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
-
-#if defined(DATA_A_F32)
-#define QUANT_K 1
-#define QUANT_R 1
-
-#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
-#define A_TYPE float
-#elif LOAD_VEC_A == 4
-#define A_TYPE vec4
-#elif LOAD_VEC_A == 8
-#define A_TYPE mat2x4
-#endif
-#endif
-
-#if defined(DATA_A_F16)
-#define QUANT_K 1
-#define QUANT_R 1
-
-#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
-#define A_TYPE float16_t
-#elif LOAD_VEC_A == 4
-#define A_TYPE f16vec4
-#elif LOAD_VEC_A == 8
-#define A_TYPE f16mat2x4
-#endif
-#endif
-
-#if defined(DATA_A_Q4_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_0
-{
- float16_t d;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_0
-#endif
-
-#if defined(DATA_A_Q4_1)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_1
-{
- float16_t d;
- float16_t m;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_1
-#endif
-
-#if defined(DATA_A_Q5_0)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_0
-{
- float16_t d;
- uint16_t qh[2];
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_0
-#endif
-
-#if defined(DATA_A_Q5_1)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_1
-{
- float16_t d;
- float16_t m;
- uint qh;
- uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_1
-#endif
-
-#if defined(DATA_A_Q8_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 1
-
-struct block_q8_0
-{
- float16_t d;
- int8_t qs[32];
-};
-
-#define A_TYPE block_q8_0
-#endif
-
-// K-quants
-#if defined(DATA_A_Q2_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q2_K
-{
- uint8_t scales[QUANT_K/16];
- uint8_t qs[QUANT_K/4];
- f16vec2 d;
-};
-
-#define A_TYPE block_q2_K
-#endif
-
-#if defined(DATA_A_Q3_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q3_K
-{
- uint8_t hmask[QUANT_K/8];
- uint8_t qs[QUANT_K/4];
- uint8_t scales[12];
- float16_t d;
-};
-
-#define A_TYPE block_q3_K
-#endif
-
-#if defined(DATA_A_Q4_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q4_K
-{
- f16vec2 d;
- uint8_t scales[3*QUANT_K/64];
- uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q4_K
-#endif
-
-#if defined(DATA_A_Q5_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q5_K
-{
- f16vec2 d;
- uint8_t scales[12];
- uint8_t qh[QUANT_K/8];
- uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q5_K
-#endif
-
-#if defined(DATA_A_Q6_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q6_K
-{
- uint8_t ql[QUANT_K/2];
- uint8_t qh[QUANT_K/4];
- int8_t scales[QUANT_K/16];
- float16_t d;
-};
-
-#define A_TYPE block_q6_K
-#endif
-
-// IQuants
-
-#if defined(DATA_A_IQ4_NL)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_iq4_nl
-{
- float16_t d;
- uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_iq4_nl
-
-const int8_t kvalues_iq4nl[16] = {
- int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
- int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
-};
-#endif
+++ /dev/null
-#version 450
-
-layout (push_constant) uniform parameter
-{
- uint ne; uint d_offset;
- uint nb00; uint nb01; uint nb02; uint nb03;
- uint ne10; uint ne11; uint ne12; uint ne13;
- float sf0; float sf1; float sf2; float sf3;
-} p;
-
-#include "types.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
- if (idx >= p.ne) {
- return;
- }
-
- const uint i10 = idx % p.ne10;
- const uint i11 = (idx / p.ne10) % p.ne11;
- const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
- const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
-
- const uint i00 = uint(i10 / p.sf0);
- const uint i01 = uint(i11 / p.sf1);
- const uint i02 = uint(i12 / p.sf2);
- const uint i03 = uint(i13 / p.sf3);
-
- data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
-}
+++ /dev/null
-#include <iostream>
-#include <fstream>
-#include <sstream>
-#include <string>
-#include <stdexcept>
-#include <array>
-#include <vector>
-#include <map>
-#include <thread>
-#include <mutex>
-#include <future>
-#include <queue>
-#include <condition_variable>
-#include <cstdio>
-#include <cstring>
-#include <cstdlib>
-#include <cassert>
-#include <cctype>  // std::toupper
-#include <cerrno>  // errno
-#include <sys/stat.h>
-#include <sys/types.h>
-
-#ifdef _WIN32
- #include <windows.h>
- #include <direct.h> // For _mkdir on Windows
- #include <algorithm> // For std::replace on w64devkit
-#else
- #include <unistd.h>
- #include <sys/wait.h>
- #include <fcntl.h>
-#endif
-
-#define ASYNCIO_CONCURRENCY 64
-
-std::mutex lock;
-std::vector<std::pair<std::string, std::string>> shader_fnames;
-
-std::string GLSLC = "glslc";
-std::string input_dir = "vulkan-shaders";
-std::string output_dir = "/tmp";
-std::string target_hpp = "ggml-vulkan-shaders.hpp";
-std::string target_cpp = "ggml-vulkan-shaders.cpp";
-bool no_clean = false;
-
-const std::vector<std::string> type_names = {
- "f32",
- "f16",
- "q4_0",
- "q4_1",
- "q5_0",
- "q5_1",
- "q8_0",
- "q2_k",
- "q3_k",
- "q4_k",
- "q5_k",
- "q6_k",
- "iq4_nl"
-};
-
-void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
-#ifdef _WIN32
- HANDLE stdout_read, stdout_write;
- HANDLE stderr_read, stderr_write;
- SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
-
- if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
- !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
- throw std::runtime_error("Failed to create stdout pipe");
- }
-
- if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
- !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
- throw std::runtime_error("Failed to create stderr pipe");
- }
-
- PROCESS_INFORMATION pi;
- STARTUPINFOA si = { sizeof(STARTUPINFOA) };
- si.dwFlags = STARTF_USESTDHANDLES;
- si.hStdOutput = stdout_write;
- si.hStdError = stderr_write;
-
- std::vector<char> cmd(command.begin(), command.end());
- cmd.push_back('\0');
-
- if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
- throw std::runtime_error("Failed to create process");
- }
-
- CloseHandle(stdout_write);
- CloseHandle(stderr_write);
-
- std::array<char, 128> buffer;
- DWORD bytes_read;
-
- while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
- stdout_str.append(buffer.data(), bytes_read);
- }
-
- while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
- stderr_str.append(buffer.data(), bytes_read);
- }
-
- CloseHandle(stdout_read);
- CloseHandle(stderr_read);
- WaitForSingleObject(pi.hProcess, INFINITE);
- CloseHandle(pi.hProcess);
- CloseHandle(pi.hThread);
-#else
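- // POSIX path: fork, redirect the child's stdout/stderr into pipes, and read both streams in the parent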
- int stdout_pipe[2];
- int stderr_pipe[2];
-
- if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
- throw std::runtime_error("Failed to create pipes");
- }
-
- pid_t pid = fork();
- if (pid < 0) {
- throw std::runtime_error("Failed to fork process");
- }
-
- if (pid == 0) {
- close(stdout_pipe[0]);
- close(stderr_pipe[0]);
- dup2(stdout_pipe[1], STDOUT_FILENO);
- dup2(stderr_pipe[1], STDERR_FILENO);
- close(stdout_pipe[1]);
- close(stderr_pipe[1]);
- execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
- _exit(EXIT_FAILURE);
- } else {
- close(stdout_pipe[1]);
- close(stderr_pipe[1]);
-
- std::array<char, 128> buffer;
- ssize_t bytes_read;
-
- while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
- stdout_str.append(buffer.data(), bytes_read);
- }
-
- while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
- stderr_str.append(buffer.data(), bytes_read);
- }
-
- close(stdout_pipe[0]);
- close(stderr_pipe[0]);
- waitpid(pid, nullptr, 0);
- }
-#endif
-}
-
-bool directory_exists(const std::string& path) {
- struct stat info;
- if (stat(path.c_str(), &info) != 0) {
- return false; // Path doesn't exist or can't be accessed
- }
- return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
-}
-
-bool create_directory(const std::string& path) {
-#ifdef _WIN32
- return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
-#else
- return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
-#endif
-}
-
-std::string to_uppercase(const std::string& input) {
- std::string result = input;
- for (char& c : result) {
- c = std::toupper(c);
- }
- return result;
-}
-
-bool string_ends_with(const std::string& str, const std::string& suffix) {
- if (suffix.size() > str.size()) {
- return false;
- }
- return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
-}
-
-static const char path_separator = '/';
-
-std::string join_paths(const std::string& path1, const std::string& path2) {
- return path1 + path_separator + path2;
-}
-
-std::string basename(const std::string &path) {
- return path.substr(path.find_last_of("/\\") + 1);
-}
-
-// variables to track number of compiles in progress
-static uint32_t compile_count = 0;
-static std::mutex compile_count_mutex;
-static std::condition_variable compile_count_cond;
-
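-// Compile a single shader variant with glslc, applying the given preprocessor defines, and record the output .spv path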
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
- std::string name = _name + (fp16 ? "" : "_fp32");
- std::string out_fname = join_paths(output_dir, name + ".spv");
- std::string in_path = join_paths(input_dir, in_fname);
-
- #ifdef _WIN32
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
- #else
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
- #endif
-
- #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
- cmd.push_back("-g");
- #endif
-
- for (const auto& define : defines) {
- cmd.push_back("-D" + define.first + "=" + define.second);
- }
-
- std::string command;
- for (const auto& part : cmd) {
- command += part + " ";
- }
-
- std::string stdout_str, stderr_str;
- try {
- // std::cout << "Executing command: ";
- // for (const auto& part : cmd) {
- // std::cout << part << " ";
- // }
- // std::cout << std::endl;
-
- execute_command(command, stdout_str, stderr_str);
- if (!stderr_str.empty()) {
- std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
- return;
- }
-
- std::lock_guard<std::mutex> guard(lock);
- shader_fnames.push_back(std::make_pair(name, out_fname));
- } catch (const std::exception& e) {
- std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
- }
- {
- std::lock_guard<std::mutex> guard(compile_count_mutex);
- assert(compile_count > 0);
- compile_count--;
- }
- compile_count_cond.notify_all();
-}
-
-std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
- std::map<std::string, std::string> result = a;
- result.insert(b.begin(), b.end());
- return result;
-}
-
-static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
- {
- // wait until fewer than N compiles are in progress.
- // 16 is an arbitrary limit; the goal is to avoid "failed to create pipe" errors.
- uint32_t N = 16;
- std::unique_lock<std::mutex> guard(compile_count_mutex);
- while (compile_count >= N) {
- compile_count_cond.wait(guard);
- }
- compile_count++;
- }
- compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16));
-}
-
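-// Generate the matrix multiplication shader variants (optionally with MUL_MAT_ID) for each supported source type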
-void matmul_shaders(bool fp16, bool matmul_id) {
- std::string load_vec = fp16 ? "8" : "4";
- std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
- std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
-
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
- std::string shader_name = "matmul";
-
- if (matmul_id) {
- base_dict["MUL_MAT_ID"] = "1";
- shader_name = "matmul_id";
- }
-
- if (fp16) {
- base_dict["FLOAT16"] = "1";
- }
-
- // Shaders with f16 B_TYPE
- string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
-
- string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
-
- for (const auto& tname : type_names) {
- std::string data_a_key = "DATA_A_" + to_uppercase(tname);
- // For unaligned, load one at a time for f32/f16, or two at a time for quants
- std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
- // For aligned matmul loads
- std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
- }
-}
-
-void process_shaders() {
- std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
-
- for (const auto& fp16 : {false, true}) {
- matmul_shaders(fp16, false);
- matmul_shaders(fp16, true);
- }
-
- for (const auto& tname : type_names) {
- // mul mat vec
- std::string data_a_key = "DATA_A_" + to_uppercase(tname);
- std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
-
- string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
-
- string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
-
- // Dequant shaders
- if (tname != "f16") {
- string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
- }
-
- if (!string_ends_with(tname, "_k")) {
- shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
-
- if (tname == "f16") {
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- } else {
- string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
- }
- string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
- }
- }
-
- string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-
- // Norms
- string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
- string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
- string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
- string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
-
- string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
-
- string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
- string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
- string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
-
- string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
- string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
-
- string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-
- string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-
- string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
-
- string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
- string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
-
- string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
- string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
- for (auto &c : compiles) {
- c.wait();
- }
-}
-
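-// Write the generated .hpp/.cpp pair, embedding each compiled SPIR-V binary as a byte array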
-void write_output_files() {
- FILE* hdr = fopen(target_hpp.c_str(), "w");
- FILE* src = fopen(target_cpp.c_str(), "w");
-
- fprintf(hdr, "#include <cstdint>\n\n");
- fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
-
- for (const auto& pair : shader_fnames) {
- const std::string& name = pair.first;
- #ifdef _WIN32
- std::string path = pair.second;
- std::replace(path.begin(), path.end(), '/', '\\' );
- #else
- const std::string& path = pair.second;
- #endif
-
- FILE* spv = fopen(path.c_str(), "rb");
- if (!spv) {
- std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
- continue;
- }
-
- fseek(spv, 0, SEEK_END);
- size_t size = ftell(spv);
- fseek(spv, 0, SEEK_SET);
-
- std::vector<unsigned char> data(size);
- size_t read_size = fread(data.data(), 1, size, spv);
- fclose(spv);
- if (read_size != size) {
- std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
- continue;
- }
-
- fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
- fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
-
- fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
- for (size_t i = 0; i < size; ++i) {
- fprintf(src, "0x%02x,", data[i]);
- if ((i + 1) % 12 == 0) fprintf(src, "\n");
- }
- fprintf(src, "\n};\n\n");
-
- if (!no_clean) {
- std::remove(path.c_str());
- }
- }
-
- fclose(hdr);
- fclose(src);
-}
-
-int main(int argc, char** argv) {
- std::map<std::string, std::string> args;
- for (int i = 1; i < argc; i += 2) {
- if (i + 1 < argc) {
- args[argv[i]] = argv[i + 1];
- }
- }
-
- if (args.find("--glslc") != args.end()) {
- GLSLC = args["--glslc"]; // Path to glslc
- }
- if (args.find("--input-dir") != args.end()) {
- input_dir = args["--input-dir"]; // Directory containing shader sources
- }
- if (args.find("--output-dir") != args.end()) {
- output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
- }
- if (args.find("--target-hpp") != args.end()) {
- target_hpp = args["--target-hpp"]; // Path to generated header file
- }
- if (args.find("--target-cpp") != args.end()) {
- target_cpp = args["--target-cpp"]; // Path to generated cpp file
- }
- if (args.find("--no-clean") != args.end()) {
- no_clean = true; // Keep temporary SPIR-V files in output-dir after build
- }
-
- if (!directory_exists(input_dir)) {
- std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
- return EXIT_FAILURE;
- }
-
- if (!directory_exists(output_dir)) {
- if (!create_directory(output_dir)) {
- std::cerr << "Error creating output directory: " << output_dir << "\n";
- return EXIT_FAILURE;
- }
- }
-
- process_shaders();
-
- write_output_files();
-
- return EXIT_SUCCESS;
-}