]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add opencl backend (skip) (llama/10693)
authorlhez <redacted>
Tue, 14 Jan 2025 07:24:03 +0000 (09:24 +0200)
committerGeorgi Gerganov <redacted>
Tue, 14 Jan 2025 08:38:01 +0000 (10:38 +0200)
---------

Co-authored-by: Skyler Szot <redacted>
Co-authored-by: Shangqing Gu <redacted>
Co-authored-by: Alexander Angus <redacted>
Co-authored-by: Hongqiang Wang <redacted>
Co-authored-by: Max Krasnyansky <redacted>
12 files changed:
ggml/src/ggml-opencl/CMakeLists.txt [new file with mode: 0644]
ggml/src/ggml-opencl/ggml-opencl.cpp [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/embed_kernel.py [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl [new file with mode: 0644]

diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
new file mode 100644 (file)
index 0000000..45328a6
--- /dev/null
@@ -0,0 +1,147 @@
+find_package(OpenCL REQUIRED)
+find_package(Python3 REQUIRED)
+
+set(TARGET_NAME ggml-opencl)
+
+ggml_add_backend_library(${TARGET_NAME}
+                         ggml-opencl.cpp
+                         ../../include/ggml-opencl.h)
+target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES})
+target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS})
+
+if (GGML_OPENCL_PROFILING)
+    message(STATUS "OpenCL profiling enabled (increases CPU overhead)")
+    add_compile_definitions(GGML_OPENCL_PROFILING)
+endif ()
+
+add_compile_definitions(GGML_OPENCL_SOA_Q)
+
+if (GGML_OPENCL_USE_ADRENO_KERNELS)
+    message(STATUS "OpenCL will use matmul kernels optimized for Adreno")
+    add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS)
+endif ()
+
+if (GGML_OPENCL_EMBED_KERNELS)
+    add_compile_definitions(GGML_OPENCL_EMBED_KERNELS)
+
+    set(OPENCL_CL_SOURCE_EMBED         "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h")
+    set(OPENCL_MM_CL_SOURCE_EMBED      "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h")
+    set(OPENCL_CVT_CL_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h")
+
+    set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED             "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h")
+    set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h")
+    set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED          "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h")
+    set(OPENCL_TRANSPOSE_16_SOURCE_EMBED               "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h")
+    set(OPENCL_TRANSPOSE_32_SOURCE_EMBED               "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h")
+    set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED            "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h")
+
+    set(EMBED_KERNEL_SCRIPT             "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
+    file(MAKE_DIRECTORY                 "${CMAKE_BINARY_DIR}/autogenerated")
+
+    include_directories("${CMAKE_BINARY_DIR}/autogenerated")
+
+    # Python must be accessible from command line
+    add_custom_command(
+        OUTPUT ${OPENCL_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl
+            ${OPENCL_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl
+            ${OPENCL_MM_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mm.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl
+            ${OPENCL_CVT_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_cvt.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl
+            ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl
+            ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
+            ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl
+            ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_16.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl
+            ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl
+            ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32_16.cl.h"
+    )
+
+    target_sources(${TARGET_NAME} PRIVATE
+                   ${OPENCL_CL_SOURCE_EMBED}
+                   ${OPENCL_MM_CL_SOURCE_EMBED}
+                   ${OPENCL_CVT_CL_SOURCE_EMBED}
+                   ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+                   ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+                   ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED})
+else ()
+    # copy ggml-opencl.cl to bin directory
+    configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY)
+
+    configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY)
+endif ()
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
new file mode 100644 (file)
index 0000000..ed90e47
--- /dev/null
@@ -0,0 +1,4004 @@
+#define CL_TARGET_OPENCL_VERSION 220
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS
+
+// suppress warnings in CL headers for GCC and Clang
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#ifdef __clang__
+#pragma GCC diagnostic ignored "-Wgnu-anonymous-struct"
+#endif
+
+#include "ggml-opencl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml.h"
+
+#include <CL/cl.h>
+
+#include <string.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <atomic>
+#include <fstream>
+#include <limits>
+#include <vector>
+#include <string>
+#include <cmath>
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define UNUSED(x) (void)(x)
+
+#define CL_CHECK(err)                                               \
+    do {                                                            \
+        cl_int err_ = (err);                                        \
+        if (err_ != CL_SUCCESS) {                                   \
+            GGML_LOG_ERROR("ggml_opencl: %s error %d at %s:%d\n",  \
+                #err, err_, __FILE__, __LINE__);                    \
+            GGML_ASSERT(0);                                         \
+        }                                                           \
+    } while (0)
+
+//------------------------------------------------------------------------------
+// OpenCL
+//------------------------------------------------------------------------------
+
+bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
+
+enum GPU_FAMILY {
+    ADRENO,
+    INTEL,
+    UNKNOWN,
+};
+
+enum ADRENO_GPU_GEN {
+    ADRENO_UNKNOWN,
+    A7X,
+    A8X,
+    X1E,
+};
+
+static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
+    if (strstr(device_name, "730") ||
+        strstr(device_name, "740") ||
+        strstr(device_name, "750")) {
+        return ADRENO_GPU_GEN::A7X;
+    }
+
+    if (strstr(device_name, "830")) {
+        return ADRENO_GPU_GEN::A8X;
+    }
+
+    if (strstr(device_name, "X1")) {
+        return ADRENO_GPU_GEN::X1E;
+    }
+
+    return ADRENO_GPU_GEN::ADRENO_UNKNOWN;
+}
+
+static int get_adreno_cl_compiler_version(const char *driver_version) {
+    std::string driver_ver_str(driver_version);
+    size_t compiler_ver_pos = driver_ver_str.find("E031");
+    size_t compiler_ver_len = 13;
+    size_t compiler_ver_offset = 5;
+
+    if (compiler_ver_pos == std::string::npos) {
+        compiler_ver_pos = driver_ver_str.find("DX");
+        if (compiler_ver_pos == std::string::npos) {
+            return -1;
+        }
+        compiler_ver_len = 11;
+        compiler_ver_offset = 3;
+    }
+
+    std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len);
+    std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2);
+    return std::atoi(major_ver_str.c_str());
+}
+
+// backend device context
+struct ggml_backend_opencl_device_context {
+    cl_platform_id platform;
+    std::string platform_name;
+
+    cl_device_id device;
+    std::string device_name;
+};
+
+// backend context
+struct ggml_backend_opencl_context {
+    cl_device_id device;
+    std::string device_name;
+
+    std::string driver_version;
+
+    GPU_FAMILY gpu_family;
+    ADRENO_GPU_GEN adreno_gen;
+
+    cl_int alignment;
+    size_t max_alloc_size;
+    bool fp16_support;
+
+    int adreno_wave_size;
+
+    cl_context context;
+    cl_command_queue queue;
+
+    cl_program program;
+    cl_program program_1;
+    cl_program program_2;
+
+    cl_kernel kernel_add, kernel_add_row;
+    cl_kernel kernel_mul, kernel_mul_row;
+    cl_kernel kernel_scale;
+    cl_kernel kernel_silu, kernel_silu_4;
+    cl_kernel kernel_gelu, kernel_gelu_4;
+    cl_kernel kernel_relu;
+    cl_kernel kernel_clamp;
+    cl_kernel kernel_norm;
+    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
+    cl_kernel kernel_soft_max, kernel_soft_max_4;
+    cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
+    cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
+    cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
+    cl_kernel kernel_mul_mat_f32_f32;
+    cl_kernel kernel_mul_mat_f16_f16;
+    cl_kernel kernel_mul_mat_f16_f32_1row;
+    cl_kernel kernel_mul_mat_f16_f32;
+    cl_kernel kernel_mul_mat_f16_f32_l4;
+    cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
+    cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat;
+    cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
+    cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0,
+              kernel_mul_mat_q4_0_f32_flat_img_v0;
+    cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+    cl_kernel kernel_mul_mv_q6_K_f32;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    // Transpose kernels
+    cl_program program_transpose_32;
+    cl_program program_transpose_32_16;
+    cl_program program_transpose_16;
+    cl_kernel kernel_transpose_32;
+    cl_kernel kernel_transpose_32_16;
+    cl_kernel kernel_transpose_16;
+
+    cl_mem A_s_d_max;            // max scale buffer size for transpose
+    cl_mem A_q_d_max;            // max weight buffer size for transpose
+    cl_mem B_d_max;              // max activation buffer size for transpose
+
+    // Gemm and Gemv related programs, kernels, etc
+    cl_program program_CL_gemm;
+    cl_program program_CL_gemv_general;
+    cl_program program_CL_gemv_4096_1_11008;
+    cl_program program_CL_gemv_4096_1_4096;
+    cl_program program_CL_gemv_11008_1_4096;
+    cl_program program_CL_gemv_32000_1_4096;
+    cl_kernel CL_mul_mat_Ab_Bi_8x4;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+};
+
+static ggml_backend_device                 g_ggml_backend_opencl_device;
+static ggml_backend_opencl_device_context  g_ggml_ctx_dev_main {
+    /*.platform         =*/ nullptr,
+    /*.platform_nane    =*/ "",
+    /*.device           =*/ nullptr,
+    /*.device_name      =*/ "",
+};
+
+static int ggml_backend_opencl_n_devices = 0;
+
+// Profiling
+#ifdef GGML_OPENCL_PROFILING
+struct ProfilingInfo {
+    std::string op_name;
+    std::string kernel_name;
+    // Kernel execution time in nanoseconds.
+    cl_ulong duration_ns;
+    // Global and local work sizes.
+    size_t global_size[3];
+    size_t local_size[3];
+    // Op output size.
+    size_t output_size[4];
+};
+
+std::vector<ProfilingInfo> g_profiling_info;
+#endif
+
+inline std::string read_file(const std::string &path) {
+  std::ifstream ifs(path);
+  if (!ifs) {
+    return "";
+  }
+  std::string text;
+  ifs.seekg(0, std::ios::end);
+  text.resize(ifs.tellg());
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(&text[0], text.size());
+  return text;
+}
+
+static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) {
+    cl_program p;
+    char *program_log;
+    size_t program_size;
+    size_t log_size;
+    int err;
+
+    program_size = strlen(program_buffer);
+
+    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
+    if(err < 0) {
+        GGML_LOG_ERROR("OpenCL error creating program");
+        exit(1);
+    }
+
+    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
+    if(err < 0) {
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
+        program_log = (char*) malloc(log_size + 1);
+        program_log[log_size] = '\0';
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
+        GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log);
+        free(program_log);
+        exit(1);
+    }
+
+    return p;
+}
+
+static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
+    static bool initialized = false;
+    static ggml_backend_opencl_context *backend_ctx = nullptr;
+
+    if (initialized) {
+        return backend_ctx;
+    }
+
+    ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
+    GGML_ASSERT(dev_ctx);
+    GGML_ASSERT(dev_ctx->platform == nullptr);
+    GGML_ASSERT(dev_ctx->device == nullptr);
+    GGML_ASSERT(backend_ctx == nullptr);
+
+    initialized = true;
+    backend_ctx = new ggml_backend_opencl_context();
+    backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
+
+    cl_int err;
+
+#ifdef GGML_PROFILE_OPENCL
+    GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n");
+#endif
+
+    struct cl_device;
+    struct cl_platform {
+        cl_platform_id id;
+        unsigned number;
+        char name[128];
+        char vendor[128];
+        struct cl_device * devices;
+        unsigned n_devices;
+        struct cl_device * default_device;
+    };
+
+    struct cl_device {
+        struct cl_platform * platform;
+        cl_device_id id;
+        unsigned number;
+        cl_device_type type;
+        char name[128];
+    };
+
+    enum { NPLAT = 16, NDEV = 16 };
+
+    struct cl_platform platforms[NPLAT];
+    unsigned n_platforms = 0;
+    struct cl_device devices[NDEV];
+    unsigned n_devices = 0;
+    struct cl_device * default_device = NULL;
+
+    cl_platform_id platform_ids[NPLAT];
+    if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {
+        GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n");
+        return backend_ctx;
+    }
+
+    for (unsigned i = 0; i < n_platforms; i++) {
+        struct cl_platform * p = &platforms[i];
+        p->number = i;
+        p->id = platform_ids[i];
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
+
+        cl_device_id device_ids[NDEV];
+        cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
+        if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
+            p->n_devices = 0;
+        } else {
+            CL_CHECK(clGetDeviceIDsError);
+        }
+        p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
+        p->default_device = NULL;
+
+        for (unsigned j = 0; j < p->n_devices; j++) {
+            struct cl_device * d = &devices[n_devices];
+            d->number = n_devices++;
+            d->id = device_ids[j];
+            d->platform = p;
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
+
+            if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
+                p->default_device = d;
+            }
+        }
+
+        if (default_device == NULL && p->default_device != NULL) {
+            default_device = p->default_device;
+        }
+    }
+
+    if (n_devices == 0) {
+        GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n");
+        return backend_ctx;
+    }
+
+    char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
+    char * user_device_string = getenv("GGML_OPENCL_DEVICE");
+    int user_platform_number = -1;
+    int user_device_number = -1;
+
+    unsigned n;
+    if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
+        user_platform_number = (int)n;
+    }
+    if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
+        user_device_number = (int)n;
+    }
+    if (user_platform_number != -1 && user_device_number != -1) {
+        cl_platform* platform = &platforms[user_platform_number];
+        if ((unsigned)user_device_number >= platform->n_devices) {
+            GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number);
+            exit(1);
+        }
+        default_device = &platform->devices[user_device_number];
+    } else {
+
+        struct cl_device * selected_devices = devices;
+        unsigned n_selected_devices = n_devices;
+
+        if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
+            for (unsigned i = 0; i < n_platforms; i++) {
+                struct cl_platform * p = &platforms[i];
+                if (strstr(p->name, user_platform_string) != NULL ||
+                    strstr(p->vendor, user_platform_string) != NULL) {
+                    user_platform_number = (int)i;
+                    break;
+                }
+            }
+            if (user_platform_number == -1) {
+                GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
+                exit(1);
+            }
+        }
+        if (user_platform_number != -1) {
+            struct cl_platform * p = &platforms[user_platform_number];
+            selected_devices = p->devices;
+            n_selected_devices = p->n_devices;
+            default_device = p->default_device;
+            if (n_selected_devices == 0) {
+                GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
+                exit(1);
+            }
+        }
+
+        if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
+            for (unsigned i = 0; i < n_selected_devices; i++) {
+                struct cl_device * d = &selected_devices[i];
+                if (strstr(d->name, user_device_string) != NULL) {
+                    user_device_number = d->number;
+                    break;
+                }
+            }
+            if (user_device_number == -1) {
+                GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string);
+                exit(1);
+            }
+        }
+        if (user_device_number != -1) {
+            selected_devices = &devices[user_device_number];
+            n_selected_devices = 1;
+            default_device = &selected_devices[0];
+        }
+
+        GGML_ASSERT(n_selected_devices > 0);
+
+        if (default_device == NULL) {
+            default_device = &selected_devices[0];
+        }
+    }
+
+    GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
+    GGML_LOG_INFO("ggml_opencl: selecting device: '%s'\n", default_device->name);
+    if (default_device->type != CL_DEVICE_TYPE_GPU) {
+        GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
+    }
+
+    dev_ctx->platform = default_device->platform->id;
+    dev_ctx->device = default_device->id;
+    backend_ctx->device = default_device->id;
+
+    if (strstr(default_device->name, "Adreno")) {
+        backend_ctx->gpu_family = GPU_FAMILY::ADRENO;
+        backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name);
+
+        // Default wave size is 128, A8x uses 64.
+        if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) {
+            backend_ctx->adreno_wave_size = 64;
+        } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X ||
+                   backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) {
+            backend_ctx->adreno_wave_size = 128;
+        } else {
+            backend_ctx->adreno_wave_size = 128;
+            GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, "
+                "using wave size %d, "
+                "may not work as expected\n",
+                backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size);
+        }
+    } else if (strstr(default_device->name, "Intel")) {
+        backend_ctx->gpu_family = GPU_FAMILY::INTEL;
+    } else {
+        GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name);
+        backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
+        return backend_ctx;
+    }
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
+        GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; "
+            "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n");
+        return backend_ctx;
+    }
+#endif
+
+    // Populate backend device name
+    dev_ctx->platform_name = default_device->platform->name;
+    dev_ctx->device_name = default_device->name;
+    backend_ctx->device_name = default_device->name;
+
+    // A local ref of cl_device_id for convenience
+    cl_device_id device = backend_ctx->device;
+
+    // Check device OpenCL version, OpenCL 2.0 or above is required
+    size_t device_ver_str_size;
+    clGetDeviceInfo(device, CL_DEVICE_VERSION, 0, NULL, &device_ver_str_size);
+    char *device_ver_buffer = (char *)alloca(device_ver_str_size + 1);
+    clGetDeviceInfo(device, CL_DEVICE_VERSION, device_ver_str_size, device_ver_buffer, NULL);
+    device_ver_buffer[device_ver_str_size] = '\0';
+    GGML_LOG_INFO("ggml_opencl: device OpenCL version: %s\n", device_ver_buffer);
+
+    if (strstr(device_ver_buffer, "OpenCL 2") == NULL &&
+        strstr(device_ver_buffer, "OpenCL 3") == NULL) {
+        GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n");
+        return backend_ctx;
+    }
+
+    // Check driver version
+    size_t driver_version_str_size;
+    clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size);
+    char *driver_version = (char *)alloca(driver_version_str_size + 1);
+    clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL);
+    driver_version[driver_version_str_size] = '\0';
+    GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version);
+    backend_ctx->driver_version = driver_version;
+
+    int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
+    bool has_vector_subgroup_broadcast =
+        adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17;
+    GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
+        has_vector_subgroup_broadcast ? "true" : "false");
+
+    size_t ext_str_size;
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size);
+    char *ext_buffer = (char *)alloca(ext_str_size + 1);
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL);
+    ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated
+    // Check if ext_buffer contains cl_khr_fp16
+    backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL;
+    GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false");
+
+    // fp16 is required
+    if (!backend_ctx->fp16_support) {
+        GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n");
+        return backend_ctx;
+    }
+
+    // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes
+    // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x)
+    if (strstr(device_ver_buffer, "OpenCL 3") &&
+        strstr(ext_buffer, "cl_khr_subgroups") == NULL &&
+        strstr(ext_buffer, "cl_intel_subgroups") == NULL) {
+        GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) "
+            "(note that subgroups is an optional feature in OpenCL 3.0)\n");
+        return backend_ctx;
+    }
+
+    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &backend_ctx->alignment, NULL));
+    GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment);
+
+    clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
+    GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
+
+    // Check SVM.
+    cl_device_svm_capabilities svm_caps;
+    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
+    GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n",
+        svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n",
+        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n",
+        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n",
+        svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
+
+    // Print out configurations
+#ifdef GGML_OPENCL_SOA_Q
+    GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n");
+#endif // GGML_OPENCL_SOA_Q
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    cl_context_properties properties[] = {
+        (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0
+    };
+
+    CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
+
+    // A local ref of cl_context for convenience
+    cl_context context = backend_ctx->context;
+
+    //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
+    //    (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err :
+    //    (queue = clCreateCommandQueue(context, device, 0, &err), err)
+    //)));
+    cl_command_queue_properties command_queue_props = 0;
+#ifdef GGML_OPENCL_PROFILING
+    command_queue_props |= CL_QUEUE_PROFILING_ENABLE;
+#endif
+    CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src {
+        #include "ggml-opencl.cl.h"
+    };
+#else
+    const std::string kernel_src = read_file("ggml-opencl.cl");
+#endif
+
+    std::string compile_opts =
+        "-cl-std=CL2.0 -cl-mad-enable -cl-unsafe-math-optimizations "
+        "-cl-finite-math-only -cl-fast-relaxed-math ";
+    backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts);
+
+    // Non matmul kernels.
+    CL_CHECK((backend_ctx->kernel_get_rows_f32       = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_get_rows_f16       = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_get_rows_q4_0      = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_add                = clCreateKernel(backend_ctx->program, "kernel_add", &err), err));
+    CL_CHECK((backend_ctx->kernel_add_row            = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul                = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_row            = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err));
+    CL_CHECK((backend_ctx->kernel_scale              = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err));
+    CL_CHECK((backend_ctx->kernel_silu               = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err));
+    CL_CHECK((backend_ctx->kernel_silu_4             = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu               = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu_4             = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_relu               = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
+    CL_CHECK((backend_ctx->kernel_clamp              = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err));
+    CL_CHECK((backend_ctx->kernel_norm               = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_rms_norm           = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_diag_mask_inf      = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err));
+    CL_CHECK((backend_ctx->kernel_diag_mask_inf_8    = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err));
+    CL_CHECK((backend_ctx->kernel_soft_max           = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err));
+    CL_CHECK((backend_ctx->kernel_soft_max_4         = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_norm_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_norm_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_neox_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_neox_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f32_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f32_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err));
+
+    // Matmul kernels.
+    CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row   = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4     = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32       = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v     = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err));
+
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat  = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_convert_block_q4_0     = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_restore_block_q4_0     = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err));
+
+    // Load additional mulmat kernels.
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_1 {
+        #include "ggml-opencl_mm.cl.h"
+    };
+#else
+    const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl");
+#endif
+    backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts);
+
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat      = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat     = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32                  = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0         = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0     = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err));
+
+    // Load additional data conversion kernels.
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_2 {
+        #include "ggml-opencl_cvt.cl.h"
+    };
+#else
+    const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl");
+#endif
+    backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts);
+
+    CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle     = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err));
+
+    // Kernels for Adreno
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_32_src {
+        #include "ggml-opencl_transpose_32.cl.h"
+    };
+#else
+    const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl");
+#endif
+    backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_32_16_src {
+        #include "ggml-opencl_transpose_32_16.cl.h"
+    };
+#else
+    const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl");
+#endif
+    backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_16_src {
+        #include "ggml-opencl_transpose_16.cl.h"
+    };
+#else
+    const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl");
+#endif
+    backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err));
+
+    // Gemv general
+    std::string CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemv_general {
+        #include "ggml-opencl_gemv_noshuffle_general.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl");
+#endif
+
+    backend_ctx->program_CL_gemv_general = build_program_from_source(
+        context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 2048, 16384
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=2048 "
+        " -DBLOCK_STRIDE_A=16384 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemv {
+        #include "ggml-opencl_gemv_noshuffle.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl");
+#endif
+
+    backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 2048, 16384
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=2048 "
+        " -DBLOCK_STRIDE_A=16384 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 5504, 44032
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=5504 "
+        " -DBLOCK_STRIDE_A=44032 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 16000, 128000
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=16000 "
+        " -DBLOCK_STRIDE_A=128000 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemm
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemm {
+        #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl");
+#endif
+    backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
+
+    // Allocate intermediate buffers and images
+    size_t max_A_q_d_bytes = 311164928;
+    size_t max_A_s_d_bytes = 38895616;
+    size_t max_B_d_bytes = 45088768;
+
+    CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
+    CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
+    CL_CHECK((backend_ctx->B_d_max   = clCreateBuffer(context, 0, max_B_d_bytes,   NULL, &err), err));
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    // For now we support a single devices
+    ggml_backend_opencl_n_devices = 1;
+
+    return backend_ctx;
+}
+
+static void ggml_cl2_free(void) {
+#ifdef GGML_OPENCL_PROFILING
+    FILE * fperf = fopen("cl_profiling.csv", "w");
+    if (!fperf) {
+        GGML_LOG_ERROR("Failed to open cl_profiling.csv\n");
+        return;
+    }
+
+    float total_kernel_time = 0;
+    fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
+    for (const ProfilingInfo & info : g_profiling_info) {
+        total_kernel_time += info.duration_ns/1.e6f;
+        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+            info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
+            info.global_size[0], info.global_size[1], info.global_size[2],
+            info.local_size[0], info.local_size[2], info.local_size[2],
+            info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
+    }
+    fclose(fperf);
+
+    GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Tensor extra management
+//------------------------------------------------------------------------------
+struct ggml_tensor_extra_cl {
+    // The buffer object that holds the data.
+    cl_mem data_device;
+    // The offset into the buffer object. This is primarily for scratch buffer
+    // and view operation.
+    // NB: this offset no longer includes view offset (view_offs). Whenever this
+    // offset is used, view_offs should be considered.
+    cl_ulong offset;
+    // The actual size of the cl_mem object. This is needed when returning the
+    // block to the pool.
+    size_t actual_size;
+
+    void reset() {
+        data_device = nullptr;
+        offset = 0;
+        actual_size = 0;
+    }
+};
+
+// Additional tensor extra structs for quantized tensors.
+// These tensors are loaded from files and should not be allocated in scratch --
+// they should always be allocated from the pool. Hence, they do not have an
+// `offset`, which indicate their locations in the scratch buffer.
+struct ggml_tensor_extra_cl_q4_0 {
+    // Quantized values.
+    cl_mem q = nullptr;
+    // Quantized values in image1d_buffer_t.
+    cl_mem q_img = nullptr;
+    // Scales.
+    cl_mem d = nullptr;
+    // Scales in image1d_buffer_t.
+    cl_mem d_img = nullptr;
+    // Size of quantized values.
+    size_t size_q = 0;
+    // Size of scales.
+    size_t size_d = 0;
+
+    ~ggml_tensor_extra_cl_q4_0() {
+        reset();
+    }
+
+    void reset() {
+        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
+        // They must be properly released so that the original buffer can be
+        // properly released to avoid memory leak.
+        if (q != nullptr) {
+            CL_CHECK(clReleaseMemObject(q));
+            q = nullptr;
+        }
+        if (d != nullptr) {
+            CL_CHECK(clReleaseMemObject(d));
+            d = nullptr;
+        }
+        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is
+        // enabled. They point to the images in ggml_backend_opencl_buffer_context.
+        // So, there is no need to release them here.
+        // TODO: initialize them for non SMALL_PATH path, or remove them.
+        q_img = nullptr;
+        d_img = nullptr;
+        size_q = 0;
+        size_d = 0;
+    }
+};
+
+//------------------------------------------------------------------------------
+// Backend API
+//------------------------------------------------------------------------------
+
+//
+// backend
+//
+static const char * ggml_backend_opencl_name(ggml_backend_t backend) {
+    return "OpenCL";
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_opencl_free(ggml_backend_t backend) {
+    ggml_cl2_free();
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(data);
+    GGML_UNUSED(offset);
+    GGML_UNUSED(size);
+}
+
+static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(data);
+    GGML_UNUSED(offset);
+    GGML_UNUSED(size);
+}
+
+static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(src);
+    GGML_UNUSED(dst);
+    return false;
+}
+
+static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
+    GGML_UNUSED(backend);
+}
+
+static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        bool ok = ggml_cl_compute_forward(backend, node);
+        if (!ok) {
+            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+        }
+        GGML_ASSERT(ok);
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    GGML_UNUSED(dev);
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+            return true;
+        case GGML_OP_GET_ROWS:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    return true;
+                case GGML_TYPE_Q4_0:
+#ifdef GGML_OPENCL_SOA_Q
+                    // We do not support flattened Q4_0 (and possibly other Q's)
+                    return false;
+#else // GGML_OPENCL_SOA_Q
+                    return true;
+#endif // GGML_OPENCL_SOA_Q
+                default:
+                    return false;
+            }
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                    switch (op->type) {
+                        case GGML_TYPE_F16:
+                        case GGML_TYPE_F32:
+                            return true;
+                        default:
+                            return false;
+                    }
+                case GGML_TYPE_F16:
+                    switch (op->type) {
+                        case GGML_TYPE_F16:
+                        case GGML_TYPE_F32:
+                            return true;
+                        default:
+                            return false;
+                    }
+                default:
+                    return false;
+            }
+        case GGML_OP_ADD:
+        case GGML_OP_SCALE:
+        case GGML_OP_MUL:
+            return true;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                   return ggml_is_contiguous(op->src[0]);
+                default:
+                    return false;
+            }
+        case GGML_OP_CLAMP:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_MUL_MAT:
+            if (op->src[0]->type == GGML_TYPE_F16) {
+                return true;
+            } else if (op->src[0]->type == GGML_TYPE_F32) {
+                return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+            } else if (op->src[0]->type == GGML_TYPE_Q4_0 ||
+                       op->src[0]->type == GGML_TYPE_Q6_K) {
+                return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+            }
+            return false;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+            return op->ne[3] == 1;
+        case GGML_OP_ROPE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+// Forward declaration - implementation appears later in the file.
+static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type);
+
+static ggml_guid_t ggml_backend_opencl_guid() {
+    static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe };
+    return &guid;
+}
+
+static ggml_backend_i ggml_backend_opencl_i = {
+    /* .get_name                = */ ggml_backend_opencl_name,
+    /* .free                    = */ ggml_backend_opencl_free,
+    /* .set_tensor_async        = */ NULL,  /* ggml_backend_opencl_set_tensor_async */
+    /* .get_tensor_async        = */ NULL,  /* ggml_backend_opencl_get_tensor_async */
+    /* .cpy_tensor_async        = */ NULL,  /* ggml_backend_opencl_cpy_tensor_async */
+    /* .synchronize             = */ NULL,  /* ggml_backend_opencl_synchronize */
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_opencl_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+ggml_backend_t ggml_backend_opencl_init(void) {
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0);
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_opencl_guid(),
+        /* .interface = */ ggml_backend_opencl_i,
+        /* .device    = */ dev,
+        /* .context   = */ backend_ctx
+    };
+
+    return backend;
+}
+
+bool ggml_backend_is_opencl(ggml_backend_t backend) {
+    return backend && backend->iface.get_name == ggml_backend_opencl_name;
+}
+
+//
+// buffer
+//
+struct ggml_backend_opencl_buffer_context {
+    // A buffer context can hold multiple cl_mem objects. This is for flattening
+    // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where
+    // each tensor is allocated a separate buffer. When flattening is enabled
+    // with small allocation, each tensor is backed by two cl_mem objects (for
+    // quants and scales) packed into a backend_opencl_buffer.
+    ggml_backend_opencl_buffer_context(cl_mem buf)
+        : name("OpenCL") {
+        buffer.push_back(buf);
+    }
+
+    ~ggml_backend_opencl_buffer_context() {
+        for (cl_mem buf : buffer) {
+            CL_CHECK(clReleaseMemObject(buf));
+        }
+        for (cl_mem im : img) {
+            CL_CHECK(clReleaseMemObject(im));
+        }
+
+        // Delete all extras to trigger their destructors
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            delete e;
+        }
+    }
+
+    ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
+        ggml_tensor_extra_cl * extra;
+        if (temp_tensor_extras.empty()) {
+            extra = new ggml_tensor_extra_cl();
+        } else {
+            extra = temp_tensor_extras.back();
+            temp_tensor_extras.pop_back();
+        }
+
+        temp_tensor_extras_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() {
+        ggml_tensor_extra_cl_q4_0 * extra;
+        if (temp_tensor_extras_q4_0.empty()) {
+            extra = new ggml_tensor_extra_cl_q4_0();
+        } else {
+            extra = temp_tensor_extras_q4_0.back();
+            temp_tensor_extras_q4_0.pop_back();
+        }
+
+        temp_tensor_extras_q4_0_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    void reset() {
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            temp_tensor_extras.push_back(e);
+        }
+        temp_tensor_extras_in_use.clear();
+
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            temp_tensor_extras_q4_0.push_back(e);
+        }
+        temp_tensor_extras_q4_0_in_use.clear();
+    }
+
+    // Pools for extras. Available extras are in `temp_tensor_extras`. Extras
+    // being used are in `temp_tensor_extras_in_use`. At the first run, new
+    // extras get created and put in `in_use`. When the buffer is reset via
+    // the `reset` callback, all extras in `in_use` get moved to available extras
+    // for reuse.
+    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras;
+    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
+
+    // The buffer_context is initially created by ggml_backend_buft_alloc_buffer
+    // before any tensor is initialized (at the beginning of alloc_tensor_range).
+    // Hence, there is alway a buffer object in this vector. When each tensor is
+    // being initialized, this original buffer object will be released if both
+    // flattening and small allocation are enabled, and additional buffer
+    // objects will be created in init_tensor to represent flattened quantized
+    // weights.
+    std::vector<cl_mem> buffer;
+    // These are image1d_buffer_t objects that wrap around the quants and scales.
+    // For Q4_0 quantization, there should be two of them - one for quants and
+    // one for scales. They should be populated only when flattening and small
+    // allocation are enabled.
+    std::vector<cl_mem> img;
+    std::string name;
+};
+
+static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
+
+static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    delete ctx;
+}
+
+static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return cl_ptr_base;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+
+    ggml_cl2_init(buffer->buft->device);
+
+    if (tensor->view_src != nullptr) {
+        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+
+        ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra;
+        GGML_ASSERT(view_extra && "view_extra is nullptr?");
+
+        // Reuse extra of the parent tensor. The offset of this view tensor
+        // becomes `extra->offset + view_offs` and needs to be calculated when
+        // it is used. This changes is needed because of the change to
+        // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640.
+        // `buffer` passed in here will always be `tensor->buffer`. It is OK
+        // to allocate extras from the same buffer context for ordinary
+        // intermediate tensors. But for views into kv cache tensors, doing so
+        // would mess up the extras used by kv cache.
+        // Before #7640, `buffer` is for intermediate tensors, which is always
+        // different from that of kv cache tensors.
+        //
+        // NB: now extra->offset no longer accounts for view_offs.
+        // NB: this should not apply to weight tensors (for end-to-end runs, but
+        //     may apply for test-backend-ops).
+        // FIXME: if any unexpected results are seen, double check the offset -
+        // there could be other places that need fix.
+        tensor->extra = view_extra;
+    } else {
+        {
+            size_t offset = (char *)tensor->data - (char *)cl_ptr_base;
+
+            ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra();
+            extra->offset = offset;
+            extra->data_device = ctx->buffer[0];
+            extra->actual_size = ggml_nbytes(tensor);
+
+            tensor->extra = extra;
+        }
+    }
+}
+
+// The optimized gemm and gemv kernels are used for large matrices without batch.
+// tensor is the quantized weights matrix.
+inline bool use_adreno_kernels(const ggml_tensor *tensor) {
+    return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 &&
+            tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
+
+    cl_context context = backend_ctx->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+#ifdef GGML_OPENCL_SOA_Q
+    // We separate the quantized bits and scale from block_q4_0 by using an
+    // additional kernel, where each thread handles a block. We first read the
+    // original weights into a temporary buffer, then create two separate
+    // buffers for quantized bits and scales, which are then populated by the
+    // conversion kernel.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        // Tensors should have been preallocated, therefore they should
+        // already have ggml_tensor_extra_cl as extra.
+        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
+        GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
+
+        // Allocate the new extra and create aliases from the original.
+        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+        ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0();
+
+        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
+        GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+        CL_CHECK(clEnqueueWriteBuffer(
+            queue, data_device, CL_TRUE, 0,
+            ggml_nbytes(tensor), data, 0, NULL, NULL));
+
+        // We consider the specified offset arg as always, although For weights
+        // the offset arg should be 0 (we do not assert this).
+        //GGML_ASSERT(offset == 0);
+
+        // We create subbuffers from the original tensor buffer for scales and
+        // quants - i.e., scales and quants are aliases into the buffer obejct
+        // that backs the original tensor. This is a cleaner way to adapt to the
+        // new memory management.
+        // In the old code, we allocate new buffers for scales and quants
+        // respectively, which could still be done but would result in double
+        // allocation; properly deallocating the preallocated buffer that backs
+        // the tensors is tricky and would leak the backend specific information
+        // into the general backend code.
+        // Does this create misaligned subbuffers (alignment is 1024) in certain
+        // cases ?
+        cl_buffer_region region;
+
+        // The original tensor memory is divided into scales and quants, i.e.,
+        // we first store scales, then quants.
+        // Create subbuffer for scales.
+        region.origin = extra_orig->offset + tensor->view_offs + offset;
+        region.size = size_d;
+        extra->d = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        // Create subbuffer for quants.
+        region.origin = extra_orig->offset + tensor->view_offs + offset + size_d;
+        region.size = size_q;
+        extra->q = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+
+        // The optimized kernels need weights in natural order, so unshuffle.
+        if (use_adreno_kernels(tensor)) {
+            kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle;
+        }
+    #else
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseMemObject(data_device));
+
+        tensor->extra = extra;
+
+        // transpose the weights and scales
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        // Only do transpose for large, non batched matrix
+        // TODO: use preallocated images instead of sub-buffer then image
+        if (use_adreno_kernels(tensor)) {
+        // <----------------------------------------------------------------------------------> //
+        // start transpose
+        // <----------------------------------------------------------------------------------> //
+        int M = tensor->ne[1];   // ne01
+        int K = tensor->ne[0];   // ne00
+
+        // transpose is out of place, so we need to allocate transposed buffers
+        // <----------------------------------------------------------------------------------> //
+        // use sub_buffer of max buffer size instead
+
+        size_t q_size_bytes = K * M / 8 * sizeof(float);
+        cl_buffer_region region;
+        region.origin = 0;
+        region.size = q_size_bytes;
+        cl_mem qT_d = clCreateSubBuffer(
+            backend_ctx->A_q_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
+        size_t d_size_bytes = M * (K / 32) * 2;
+        region.origin = 0;
+        region.size = d_size_bytes;
+        cl_mem dT_d = clCreateSubBuffer(
+            backend_ctx->A_s_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // <----------------------------------------------------------------------------------> //
+
+
+        // create images from the buffers
+        // <----------------------------------------------------------------------------------> //
+        cl_mem q_d_image1D;
+        cl_mem d_d_image1D;
+        cl_mem qT_d_image1D;
+        cl_mem dT_d_image1D;
+
+        cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        cl_image_desc img_desc_1d;
+
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = extra->q;
+        q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = qT_d;
+        qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = extra->d;
+        d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = dT_d;
+        dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+        // <----------------------------------------------------------------------------------> //
+
+        // set up and call the transpose kernels
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        int height_q = M / 8;
+        int width_q = K / 8 / 4;
+        kernel = backend_ctx->kernel_transpose_16;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_q));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_q));
+
+        size_t local_size_q[3] = {4, 16, 1};
+        size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        int height_s = M / 8;
+        int width_s = K / 32 / 8;
+
+        kernel = backend_ctx->kernel_transpose_16;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
+
+        size_t local_size_s[3] = {4, 16, 1};
+        size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // copy transposed buffer contents to original buffers
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // deallocate transpose buffers
+        // <----------------------------------------------------------------------------------> //
+        CL_CHECK(clReleaseMemObject(qT_d));
+        CL_CHECK(clReleaseMemObject(dT_d));
+
+        // deallocate temporary images
+        CL_CHECK(clReleaseMemObject(q_d_image1D));
+        CL_CHECK(clReleaseMemObject(d_d_image1D));
+        CL_CHECK(clReleaseMemObject(qT_d_image1D));
+        CL_CHECK(clReleaseMemObject(dT_d_image1D));
+        // <----------------------------------------------------------------------------------> //
+        // end transpose
+        // <----------------------------------------------------------------------------------> //
+        }
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+        return;
+    }
+#endif // GGML_OPENCL_SOA_Q
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueWriteBuffer(
+        queue, extra->data_device, CL_TRUE, extra->offset + offset,
+        size, data, 0, NULL, NULL));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->extra);
+
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
+
+    cl_context context = backend_ctx->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    // Make sure all previously submitted commands are finished.
+    CL_CHECK(clFinish(queue));
+
+#ifdef GGML_OPENCL_SOA_Q
+    // In end-to-end runs, get_tensor is usually used to get back the logits,
+    // where we can simply do clEnqueueReadBuffer since they are f32.
+    // However, in test-backend-ops, the GPU graph is copied to the CPU backend,
+    // which requires reading back quantized weight tensors.
+    // To properly support this, we need to restore block_q4_0 struct arrays
+    // from the flattened buffers.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra;
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+
+        cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {1, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+            global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clEnqueueReadBuffer(
+            queue, data_device, CL_TRUE, offset,
+            size, data, 0, NULL, NULL));
+        CL_CHECK(clReleaseMemObject(data_device));
+        return;
+    }
+#endif // GGML_OPENCL_SOA_Q
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+
+    CL_CHECK(clEnqueueReadBuffer(
+        queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset,
+        size, data, 0, NULL, NULL));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_dev_t dev = buffer->buft->device;
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    for (cl_mem buf : ctx->buffer) {
+        CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
+    }
+    CL_CHECK(clFinish(queue));
+}
+
+static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    ctx->reset();
+}
+
+static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_opencl_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_opencl_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_opencl_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_opencl_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_opencl_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_opencl_buffer_clear,
+    /* .reset           = */ ggml_backend_opencl_buffer_reset,
+};
+
+//
+// buffer type
+//
+
+static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) {
+    return "OpenCL";
+
+    GGML_UNUSED(buffer_type);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device);
+
+    // clCreateBuffer returns -61 for size 0
+    size = std::max(size, (size_t)1);
+
+    cl_int err;
+    cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err);
+    if (err != CL_SUCCESS) {
+        GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
+        return nullptr;
+    }
+
+    ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem);
+
+    return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
+    // FIXME: not thread safe, device may not be initialized yet
+    static cl_uint alignment = -1;
+    if (alignment == (cl_uint)-1) {
+        ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
+        alignment = backend_ctx->alignment;
+    }
+    return alignment;
+}
+
+static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
+    static size_t max_size = -1;
+    if (max_size == (size_t)-1) {
+        ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
+        max_size = backend_ctx->max_alloc_size;
+    }
+    return max_size;
+}
+
+static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_opencl(backend);
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_opencl_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_opencl_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_opencl_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_opencl_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ NULL,
+    /* .is_host          = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
+    static ggml_backend_buffer_type buffer_type = {
+        /* .iface   = */ ggml_backend_opencl_buffer_type_interface,
+        /* .device  = */ &g_ggml_backend_opencl_device,
+        /* .context = */ nullptr,
+    };
+
+    return &buffer_type;
+}
+
+//
+// backend device
+//
+
+static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) {
+    return "GPUOpenCL";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context;
+    return dev_ctx->device_name.c_str();
+}
+
+static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    *free = 1;
+    *total = 1;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_opencl_device_get_name(dev);
+    props->description = ggml_backend_opencl_device_get_description(dev);
+    props->type        = ggml_backend_opencl_device_get_type(dev);
+    ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = ggml_backend_dev_caps {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev);
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_opencl_guid(),
+        /* .interface = */ ggml_backend_opencl_i,
+        /* .device    = */ dev,
+        /* .context   = */ backend_ctx,
+    };
+
+    return backend;
+
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_opencl_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    GGML_UNUSED(dev);
+    GGML_UNUSED(ptr);
+    GGML_UNUSED(size);
+    GGML_UNUSED(max_tensor_size);
+    return nullptr;
+}
+
+static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return ggml_opencl_supports_op(dev, op);
+}
+
+static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name;
+
+    GGML_UNUSED(dev);
+}
+
+static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
+    /* .get_name             = */ ggml_backend_opencl_device_get_name,
+    /* .get_description      = */ ggml_backend_opencl_device_get_description,
+    /* .get_memory           = */ ggml_backend_opencl_device_get_memory,
+    /* .get_type             = */ ggml_backend_opencl_device_get_type,
+    /* .get_props            = */ ggml_backend_opencl_device_get_props,
+    /* .init_backend         = */ ggml_backend_opencl_device_init,
+    /* .get_buffer_type      = */ ggml_backend_opencl_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_opencl_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_opencl_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// Backend registry
+
+static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) {
+    return "OpenCL";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) {
+    return ggml_backend_opencl_n_devices;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_opencl_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = {
+    /* .get_name         = */ ggml_backend_opencl_reg_get_name,
+    /* .device_count     = */ ggml_backend_opencl_reg_device_count,
+    /* .device_get       = */ ggml_backend_opencl_reg_device_get,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_opencl_reg(void) {
+    // TODO: make this thread-safe somehow?
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    if (!initialized) {
+        reg = ggml_backend_reg {
+            /* .api_version = */ GGML_BACKEND_API_VERSION,
+            /* .iface   = */ ggml_backend_opencl_reg_i,
+            /* .context = */ NULL,
+        };
+
+        g_ggml_backend_opencl_device = ggml_backend_device {
+            /* .iface   = */ ggml_backend_opencl_device_i,
+            /* .reg     = */ &reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+
+        ggml_cl2_init(&g_ggml_backend_opencl_device);
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg)
+
+//------------------------------------------------------------------------------
+// Debugging utils
+//------------------------------------------------------------------------------
+#if 0
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2,
+    "wrong q4_0 block size/padding");
+
+#include <math.h>
+#ifdef __cplusplus
+#include "half.hpp"
+#endif
+
+static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) {
+    void * buf = malloc(ggml_nbytes(tensor));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+#ifdef GGML_OPENCL_SOA_Q
+    void * buf_q;
+    void * buf_d;
+#endif
+
+#ifdef GGML_USE_OPENCL
+    // Make sure everything is done.
+    CL_CHECK(clFinish(queue));
+
+#ifdef GGML_OPENCL_SOA_Q
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2;
+        size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t);
+        GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor));
+        buf_q = malloc(size_q);
+        buf_d = malloc(size_d);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    } else {
+        // Read out the tensor from GPU memory.
+        ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    }
+#else
+    // Read out the tensor from GPU memory.
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+    CL_CHECK(clFinish(queue));
+#endif // GGML_OPENCL_SOA_Q
+#endif // GGML_USE_OPENCL
+
+    // Open file and dump.
+    char fname[512];
+    sprintf(fname, "./tensor-dumps/%s.txt", tensor->name);
+    FILE * f = fopen(fname, "w");
+    if (!f) {
+        printf("Failed to open %s\n", fname);
+        return;
+    }
+
+    if (tensor->type == GGML_TYPE_F32) {
+        float * data = (float *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%f\n", data[i]);
+        }
+    } else if (tensor->type == GGML_TYPE_I32) {
+        int * data = (int *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%d\n", data[i]);
+        }
+    } else if (tensor->type == GGML_TYPE_F16) {
+#ifdef __cplusplus
+        half_float::half * data = (half_float::half *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (std::isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%f\n", float(data[i]));
+        }
+#endif
+    } else if (tensor->type == GGML_TYPE_Q4_0) {
+#ifdef GGML_OPENCL_SOA_Q
+        ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d;
+        unsigned char * data_q = (unsigned char *)buf_q;
+
+        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {
+            fprintf(f, "%04x, ", data_d[i]);
+            for (int k = 0; k < QK4_0/2; ++k) {
+                fprintf(f, "%02x, ", data_q[k]);
+            }
+            fprintf(f, "\n");
+            data_q += QK4_0/2;
+        }
+        free(buf_d);
+        free(buf_q);
+#else
+        block_q4_0 * data = (block_q4_0 *) buf;
+        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {
+            fprintf(f, "%04x, ", data[i].d);
+            for (int k = 0; k < QK4_0/2; ++k) {
+                fprintf(f, "%02x, ", data[i].qs[k]);
+            }
+            fprintf(f, "\n");
+        }
+#endif // GGML_OPENCL_SOA_Q
+    }
+    free(buf);
+    fflush(f);
+    fclose(f);
+}
+#else
+#define dump_tensor(tensor)
+#endif
+
+//------------------------------------------------------------------------------
+// Profiling utility
+//------------------------------------------------------------------------------
+#ifdef GGML_OPENCL_PROFILING
+void populateProfilingInfo(
+        ProfilingInfo& info, cl_event evt, cl_kernel kernel,
+        size_t global_size[3], size_t local_size[3],
+        const ggml_tensor * tensor) {
+    cl_ulong start;
+    cl_ulong end;
+    CL_CHECK(clWaitForEvents(1, &evt));
+    CL_CHECK(clGetEventProfilingInfo(
+        evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
+    CL_CHECK(clGetEventProfilingInfo(
+        evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
+
+    char kernel_name[512];
+    CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
+        sizeof(kernel_name), kernel_name, NULL));
+
+    info.duration_ns = end - start;
+    info.op_name = tensor->name;
+    info.kernel_name = kernel_name;
+    info.local_size[0]  = local_size[0];
+    info.local_size[1]  = local_size[1];
+    info.local_size[2]  = local_size[2];
+    info.global_size[0] = global_size[0];
+    info.global_size[1] = global_size[1];
+    info.global_size[2] = global_size[2];
+    info.output_size[0] = tensor->ne[0];
+    info.output_size[1] = tensor->ne[1];
+    info.output_size[2] = tensor->ne[2];
+    info.output_size[3] = tensor->ne[3];
+}
+#endif
+
+//------------------------------------------------------------------------------
+// Ops
+//------------------------------------------------------------------------------
+
+static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+            src1->type == GGML_TYPE_F32 &&
+             dst->type == GGML_TYPE_F32 &&
+            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
+}
+
+static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    UNUSED(backend);
+    UNUSED(src0);
+    UNUSED(src1);
+    UNUSED(dst);
+}
+
+static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int      ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const int      ne10 = src1 ? src1->ne[0] : 0;
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const int      ne11 = src1 ? src1->ne[1] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb1  = dst  ?  dst->nb[1] : 0;
+    const cl_ulong nb2  = dst  ?  dst->nb[2] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            kernel = backend_ctx->kernel_get_rows_f32;
+            break;
+        case GGML_TYPE_F16:
+            kernel = backend_ctx->kernel_get_rows_f16;
+            break;
+        case GGML_TYPE_Q4_0:
+            kernel = backend_ctx->kernel_get_rows_q4_0;
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
+
+    size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
+    size_t local_work_size[] = {1, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int  ne10 = src1 ? src1->ne[0] : 0;
+    const int  ne11 = src1 ? src1->ne[1] : 0;
+    const int  ne12 = src1 ? src1->ne[2] : 0;
+    const int  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+    const int  ne0  = dst ? dst->ne[0] : 0;
+    const int  ne1  = dst ? dst->ne[1] : 0;
+    const int  ne2  = dst ? dst->ne[2] : 0;
+    const int  ne3  = dst ? dst->ne[3] : 0;
+
+    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
+    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
+    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
+    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    bool bcast_row = false;
+    cl_kernel kernel;
+
+    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+
+        // src1 is a row
+        GGML_ASSERT(ne11 == 1);
+
+        bcast_row = true;
+        int ne = ne00 / 4;
+        kernel = backend_ctx->kernel_add_row;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    } else {
+        kernel = backend_ctx->kernel_add;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    }
+
+    if (bcast_row) {
+        int n = ggml_nelements(dst)/4;
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        unsigned int nth = MIN(64, ne0);
+        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const int ne01 = src0 ? src0->ne[1] : 0;
+    const int ne02 = src0 ? src0->ne[2] : 0;
+    const int ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0;
+    const int ne12 = src1 ? src1->ne[2] : 0;
+    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+    const int ne0  = dst ? dst->ne[0] : 0;
+    const int ne1  = dst ? dst->ne[1] : 0;
+    const int ne2  = dst ? dst->ne[2] : 0;
+    const int ne3  = dst ? dst->ne[3] : 0;
+
+    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
+    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
+    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
+    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    bool bcast_row = false;
+    cl_kernel kernel;
+
+    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+
+        // src1 is a row
+        GGML_ASSERT(ne11 == 1);
+
+        bcast_row = true;
+        int ne = ne00 / 4;
+        kernel = backend_ctx->kernel_mul_row;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    } else {
+        kernel = backend_ctx->kernel_mul;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    }
+
+    if (bcast_row) {
+        int n = ggml_nelements(dst)/4;
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        unsigned int nth = MIN(64, ne0);
+        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_gelu_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_gelu;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt);
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL);
+#endif
+}
+
+static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_silu_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_silu;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_relu;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    const int64_t n = ggml_nelements(dst);
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float min;
+    float max;
+    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+    cl_kernel kernel = backend_ctx->kernel_clamp;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &min));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &max));
+
+    const int64_t n = ggml_nelements(dst);
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    const int nth = MIN(64, ne00);
+
+    cl_kernel kernel = backend_ctx->kernel_norm;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float),     &eps));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL));
+
+    const int64_t nrows = ggml_nrows(src0);
+
+    size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_backend_opencl_device_context * dev_ctx =
+        (ggml_backend_opencl_device_context *)backend->device->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+
+    GGML_ASSERT(ne00 % 4 == 0);
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    const int nth = MIN(64, ne00);
+
+    const int64_t nrows = ggml_nrows(src0);
+
+    size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm;
+
+    // Note, this kernel declares local memory in kernel args and the size
+    // depends on subgroup size.
+    // Retrieve subgroup size.
+    // Note, this requires OpenCL 2.1 and above
+    size_t sgs;
+    CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device,
+        CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
+        sizeof(local_work_size), local_work_size,
+        sizeof(size_t), &sgs, NULL));
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float),     &eps));
+    // This is local memory - the size depends on subgroup size.
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs,  NULL));
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+#ifdef GGML_OPENCL_SOA_Q
+    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
+#endif
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int  ne10 = src1 ? src1->ne[0] : 0;
+    const int  ne11 = src1 ? src1->ne[1] : 0;
+    const int  ne12 = src1 ? src1->ne[2] : 0;
+    const int  ne13 = src1 ? src1->ne[3] : 0;
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;
+
+    const int  ne0 = dst ? dst->ne[0] : 0;
+    const int  ne1 = dst ? dst->ne[1] : 0;
+
+    int r2 = ne12/ne02;
+    int r3 = ne13/ne03;
+
+    GGML_ASSERT(ne00 == ne10);
+
+    int nth0 = 32;
+    int nth1 = 1;
+    int nrows = 1;
+    // The number of values produced by each subgroup
+    int ndst = 4;
+
+    cl_kernel kernel;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    cl_context context = backend_ctx->context;
+
+    if (ne01 && ne1 && use_adreno_kernels(src0)) {
+
+    // init CL objects
+    // <--------------------------------------------> //
+    cl_int              status;
+    cl_image_format     img_fmt_1d;
+    cl_image_desc       img_desc_1d;
+    cl_buffer_region    region;
+    cl_mem              A_image1d = nullptr;
+    cl_mem              B_image1d = nullptr;
+    cl_mem              B_sub_buffer = nullptr;
+    cl_mem              C_d = nullptr;
+    // for B transpose
+    cl_mem B_d = nullptr;
+    cl_mem B_d_input_image = nullptr;
+    // <--------------------------------------------> //
+
+    // define matrix dimensions
+    // <--------------------------------------------> //
+    int M = ne01;
+    int N = ne1;
+    int K = ne00;
+    int padding;
+    // <--------------------------------------------> //
+
+    // q4_0 x fp32
+    if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
+        // TODO: remove duplicate definitions of image description + format -- move to top
+
+        // create an image for A
+        // <--------------------------------------------> //
+        if (N == 1) {
+            img_fmt_1d = { CL_R, CL_UNSIGNED_INT32};
+        } else {
+            img_fmt_1d = { CL_R, CL_FLOAT};
+        }
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 2 / 4;    // Divide by 4 for char -> float
+        img_desc_1d.buffer = extra0_q4_0->q;
+        A_image1d = clCreateImage(
+            context,
+            CL_MEM_READ_ONLY,
+            &img_fmt_1d,
+            &img_desc_1d,
+            NULL,
+            &status);
+        CL_CHECK(status);
+        // <--------------------------------------------> //
+
+
+        // create a sub_buffer for B
+        // <--------------------------------------------> //
+        region.origin = (extra1->offset);
+        region.size = K * N * sizeof(float);
+        B_sub_buffer = clCreateSubBuffer(
+            extra1->data_device,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &status);
+        CL_CHECK(status);
+        // <--------------------------------------------> //
+
+        // transpose activation for Skyler's gemm
+        if (N != 1) {
+            //how many extra elements beyond multiple of 8
+            int extra_elements = N % 8;
+
+            //how much padding to add
+            padding = 0;
+            if (extra_elements > 0){
+                padding = 8 - extra_elements;
+            }
+
+            // Specify the starting offset (in bytes)
+            region.origin = 0;
+            // Specify the size of the sub-buffer (divide by 2 for FP16)
+            region.size = K * (N + padding) * sizeof(float)/2;
+            B_d = clCreateSubBuffer(
+                backend_ctx->B_d_max,
+                0,
+                CL_BUFFER_CREATE_TYPE_REGION,
+                &region,
+                &status);
+            CL_CHECK(status);
+
+            cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT };
+            cl_image_desc image_desc_B_d_input = {
+                CL_MEM_OBJECT_IMAGE1D_BUFFER,
+                static_cast<size_t>(K * N / 4),
+                0, 0, 0, 0, 0, 0, 0, { B_sub_buffer }
+            };
+            B_d_input_image = clCreateImage(
+                context,
+                0,
+                &image_format_B_d_input,
+                &image_desc_B_d_input,
+                NULL,
+                &status);
+            CL_CHECK(status);
+
+            cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)
+            cl_image_desc image_desc_B_d_output = {
+                CL_MEM_OBJECT_IMAGE1D_BUFFER,
+                static_cast<size_t>(K * (N + padding)/4),
+                0, 0, 0, 0, 0, 0, 0, { B_d }
+            };
+            B_image1d = clCreateImage(
+                context,
+                0,
+                &image_format_B_d_output,
+                &image_desc_B_d_output,
+                NULL,
+                &status);
+            CL_CHECK(status);
+
+            int height_B = N/4;
+            int width_B = K/4;
+            int padded_height_B = (N + padding)/4;
+
+            kernel = backend_ctx->kernel_transpose_32_16;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));
+
+            size_t local_size_t[2] = { 1, 16 };
+            //WGS tuning
+            if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=4;
+                local_size_t[1]=8;
+            } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=2;
+                local_size_t[1]=8;
+            } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) {
+                local_size_t[0]=1;
+                local_size_t[1]=8;
+            } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=2;
+                local_size_t[1]=8;
+            }
+
+            size_t global_size_t[2] = {
+                static_cast<size_t>(width_B),
+                static_cast<size_t>(padded_height_B)
+            };
+
+            #ifdef GGML_OPENCL_PROFILING
+                cl_event evt;
+                CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, &evt));
+
+                g_profiling_info.emplace_back();
+                populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size_t, local_size_t, dst);
+            #else
+                CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, NULL));
+            #endif
+        } else {
+            // no need to transpose B in other cases
+            // create an image for B from sub_buffer
+            // <--------------------------------------------> //
+            img_fmt_1d = {CL_RGBA, CL_FLOAT};
+
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_desc_1d.image_width = K * N / 4;
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.buffer = B_sub_buffer;
+            B_image1d = clCreateImage(
+                context,
+                CL_MEM_READ_ONLY,
+                &img_fmt_1d,
+                &img_desc_1d,
+                NULL,
+                &status);
+            CL_CHECK(status);
+            // <--------------------------------------------> //
+        }
+
+        // choose gemm or gemv kernel
+        // <--------------------------------------------> //
+        if (N == 1) {
+            kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;
+            if (M == 4096 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
+            } else if (M == 4096 && K == 11008) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;
+            } else if (M == 11008 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
+            } else if (M == 32000 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
+            }
+        } else {
+            kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4;
+        }
+        // <--------------------------------------------> //
+
+        // set kernel args
+        // <--------------------------------------------> //
+        cl_uint k_arg = 0;
+
+        if (N == 1) {
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &A_image1d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &B_image1d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extra1->offset));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extrad->offset));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r3));
+        } else {
+            region.origin = extrad->offset; // Specify the starting offset (in bytes)
+            region.size = M * N * sizeof(float); // Specify the size of the sub-buffer
+            C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+            CL_CHECK(status);
+
+            int padded_N = ne1 + padding;
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &ne01)); //M
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),    &padded_N)); //N with padding
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),    &ne00)); //K
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),    &ne1)); //N without padding
+        }
+        // <--------------------------------------------> //
+
+        // choose workgroup size
+        // <--------------------------------------------> //
+        size_t global_work_size[3] = {
+            64, static_cast<size_t>((M+63)/64), static_cast<size_t>((N+31)/32)};
+        size_t local_work_size[3] = {64, 2, 4};
+
+        global_work_size[0] = (size_t)(ceil((float)ne1/8));
+        global_work_size[1] = (size_t)(ne01/4);
+        global_work_size[2] = (size_t)(1);
+
+        local_work_size[0]  = (size_t)(1); //4x32 for FP32
+        local_work_size[1]  = (size_t)(128);
+        local_work_size[2]  = (size_t)(1);
+
+        //WGS tuning
+        if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 1;
+            local_work_size[1] = 128;
+        } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        }
+
+        if (N == 1) {
+            local_work_size[0] = backend_ctx->adreno_wave_size; // localsize
+            local_work_size[1] = 4; // reduce factor
+            local_work_size[2] = 1;
+
+            global_work_size[0] = M / 2;
+            global_work_size[1] = 4; // reduce factor
+            global_work_size[2] = 1;
+        }
+        // <--------------------------------------------> //
+
+        // enqueue kernel with profiling
+        // <--------------------------------------------> //
+    #ifdef GGML_OPENCL_PROFILING
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+        // enqueue kernel without profiling
+    #else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+    #endif
+        // <--------------------------------------------> //
+
+        // deallocate sub buffers and images
+        // <--------------------------------------------> //
+        CL_CHECK(clReleaseMemObject(A_image1d));
+        CL_CHECK(clReleaseMemObject(B_sub_buffer));
+        CL_CHECK(clReleaseMemObject(B_image1d));
+
+        if (N != 1) {
+            CL_CHECK(clReleaseMemObject(B_d));
+            CL_CHECK(clReleaseMemObject(B_d_input_image));
+            CL_CHECK(clReleaseMemObject(C_d));
+        }
+        // <--------------------------------------------> //
+
+        return;
+    }
+    } // if (ne01 && ne1)
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    if (!ggml_is_transposed(src0) &&
+        !ggml_is_transposed(src1) &&
+        src1t == GGML_TYPE_F32 &&
+        ne00%32 == 0 &&
+        ne11 > 2) {
+#ifdef GGML_OPENCL_SOA_Q
+        // Set up kernel.
+        switch(src0t) {
+            case GGML_TYPE_Q4_0:
+                // This should have been satisfied.
+                GGML_ASSERT(ne11 == ne1);
+                GGML_ASSERT(ne01 == ne0);
+
+                if (backend_ctx->gpu_family == INTEL) {
+                    nth0 = 16;
+                    nth1 = 1;
+
+                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat;
+                } else if (backend_ctx->gpu_family == ADRENO) {
+                    nth0 = 64;
+                    nth1 = 1;
+
+                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat;
+                } else {
+                    GGML_ASSERT(false && "TODO: Unknown GPU");
+                }
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+                break;
+            default:
+                break;
+        }
+
+        // Launch kernel.
+        if (src0t == GGML_TYPE_Q4_0) {
+            size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+            if (backend_ctx->gpu_family == INTEL) {
+                // Set global size for Intel. It uses 16x output values.
+                global_work_size[0] = (size_t)(ne01 + 15)/16*nth0;
+                global_work_size[1] = (size_t)ne11*nth1;
+                global_work_size[2] = (size_t)ne12*ne13;
+            }
+
+#ifdef GGML_OPENCL_PROFILING
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+            g_profiling_info.emplace_back();
+            populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+            return;
+        }
+#else // GGML_OPENCL_SOA_Q
+        // TODO: add block_q4_0 variant.
+#endif // GGML_OPENCL_SOA_Q
+    }
+
+    // use custom matrix x vector kernel
+    switch (src0t) {
+        case GGML_TYPE_F32:
+            //GGML_ASSERT(ne02 == ne12);
+            GGML_ASSERT(src1t == GGML_TYPE_F32);
+            kernel = backend_ctx->kernel_mul_mat_f32_f32;
+            nrows = 4;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 32;
+                nth1 = 1;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            break;
+        case GGML_TYPE_F16:
+            //GGML_ASSERT(ne02 == ne12);
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 32;
+                nth1 = 1;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            if (src1t == GGML_TYPE_F32) {
+                if (ne11 * ne12 < 4) {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32_1row;
+                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32_l4;
+                    nrows = ne11;
+                } else {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32;
+                    nrows = 4;
+                }
+            } else {
+                kernel = backend_ctx->kernel_mul_mat_f16_f16;
+                nrows = 4;
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            break;
+        case GGML_TYPE_Q4_0:
+            // This should have been satisfied.
+            GGML_ASSERT(ne11 == ne1);
+            GGML_ASSERT(ne01 == ne0);
+
+#ifdef GGML_OPENCL_SOA_Q
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;
+                ndst = 8;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;
+                ndst =8;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#else // GGML_OPENCL_SOA_Q
+            if (backend_ctx->gpu_family == INTEL) {
+                // Use 1D local size. Each workgroup is a SIMD group. Each SIMD
+                // group produces N_DST (4 for Q4_0 kernel) values in the result.
+                // The number of workgroups on dim 0 (the leading dimension) is
+                // the nearest multiple of 4 that covers ne0 (equals ne01).
+                nth0 = 16;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#endif // GGML_OPENCL_SOA_Q
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 2;
+                nth1 = 16;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 2;
+                nth1 = 64;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    if (src0t == GGML_TYPE_Q4_0 ||
+        src0t == GGML_TYPE_Q4_1 ||
+        src0t == GGML_TYPE_Q8_0 ||
+        src0t == GGML_TYPE_Q2_K) {
+        // Each SIMD group produces N_DST values in the result. Assuming each
+        // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will
+        // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size
+        // (number of workgroups) will be a nearest multiple of
+        // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is
+        // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul).
+        size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else if (src0t == GGML_TYPE_Q4_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q3_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q5_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q6_K) {
+        size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        int64_t ny = (ne11 + nrows - 1)/nrows;
+
+        size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_UNUSED(src1);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(scale));
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_scale;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &scale));
+
+    int n = ggml_nelements(dst)/4;
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+
+    // GGML_OP_CPY happens between src0 and src1.
+    // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst.
+    UNUSED(dst);
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const int ne01 = src0 ? src0->ne[1] : 0;
+    const int ne02 = src0 ? src0->ne[2] : 0;
+    const int ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0;
+    const int ne12 = src1 ? src1->ne[2] : 0;
+    const int ne13 = src1 ? src1->ne[3] : 0;
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+
+    cl_kernel kernel;
+
+    switch (src0t) {
+        case GGML_TYPE_F32:
+            switch (src1t) {
+                case GGML_TYPE_F16:
+                    kernel = backend_ctx->kernel_cpy_f32_f16;
+                    break;
+                case GGML_TYPE_F32:
+                    kernel = backend_ctx->kernel_cpy_f32_f32;
+                    break;
+                default:
+                    GGML_ASSERT(false && "not implemented");
+            }
+            break;
+        case GGML_TYPE_F16:
+            switch (src1t) {
+                case GGML_TYPE_F16:
+                    kernel = backend_ctx->kernel_cpy_f16_f16;
+                    break;
+                case GGML_TYPE_F32:
+                    kernel = backend_ctx->kernel_cpy_f16_f32;
+                    break;
+                default:
+                    GGML_ASSERT(false && "not implemented");
+            }
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+
+    const int nth = MIN(64, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, src1);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cl_cpy(backend, src0, dst, nullptr);
+    UNUSED(src1);
+}
+
+static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    int n_past = ((int32_t *)(dst->op_params))[0];
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    if (ne00%8 == 0) {
+        kernel = backend_ctx->kernel_diag_mask_inf_8;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));
+
+        size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        kernel = backend_ctx->kernel_diag_mask_inf;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));
+
+        size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    // Softmax can now fuse KQ mask and KQ scale, which used to be two additional
+    // ops before softmax. It now also fuses alibi if `max_bias > 0`. For llama,
+    // alibi is not used; however, for some other models, it is used.
+    // KQ_mask
+    if (src1) {
+        GGML_ASSERT(src1);
+        GGML_ASSERT(src1->extra);
+    }
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    float scale, max_bias;
+    memcpy(&scale,    dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+    const int nrows_x = ggml_nrows(src0);
+    const int nrows_y = src0->ne[1];
+
+    const int n_head      = nrows_x/nrows_y;
+    const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // Local size must be wave size. Each workgroup is a wave, working on a row,
+    // where a row corresponds to leading dimension.
+    int nth = MIN(32, ne00);
+
+    if (backend_ctx->gpu_family == INTEL) {
+        // This is the same as the initial value.
+        nth = MIN(32, ne00);
+    }
+    else if (backend_ctx->gpu_family == ADRENO) {
+        nth = 64;
+    } else {
+        GGML_ASSERT(false && "TODO: Unknown GPU");
+    }
+
+    cl_kernel kernel;
+
+    if (ne00%4 == 0) {
+        kernel = backend_ctx->kernel_soft_max_4;
+    } else {
+        kernel = backend_ctx->kernel_soft_max;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &n_head_log2));
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_tensor * src2 = dst->src[2];
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
+
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const int  nb00 = src0 ? src0->nb[0] : 0;
+    const int  nb01 = src0 ? src0->nb[1] : 0;
+    const int  nb02 = src0 ? src0->nb[2] : 0;
+    const int  nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11);
+    const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12);
+    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const int  ne0 = dst ? dst->ne[0] : 0;
+    const int  ne1 = dst ? dst->ne[1] : 0;
+    const int  ne2 = dst ? dst->ne[2] : 0;
+    const int  ne3 = dst ? dst->ne[3] : 0;
+
+    const int  nb0 = dst ? dst->nb[0] : 0;
+    const int  nb1 = dst ? dst->nb[1] : 0;
+    const int  nb2 = dst ? dst->nb[2] : 0;
+    const int  nb3 = dst ? dst->nb[3] : 0;
+
+    GGML_ASSERT(ne10 == ne02);
+
+    int nth = MIN(64, ne00);
+
+    const int n_past     = ((int *) dst->op_params)[0];
+    const int n_dims     = ((int *) dst->op_params)[1];
+    const int mode       = ((int *) dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    const bool is_neox = mode & 2;
+
+    cl_kernel kernel;
+
+    if (!is_neox) {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_norm_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_norm_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        };
+    } else {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_neox_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_neox_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        };
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_past));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &n_dims));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),      &n_ctx_orig));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &freq_base));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float),    &freq_scale));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &ext_factor));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow));
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Op offloading
+//------------------------------------------------------------------------------
+
+typedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) {
+    ggml_cl_func_t func = nullptr;
+
+    ggml_tensor * src0 = tensor->src[0];
+    ggml_tensor * src1 = tensor->src[1];
+
+    const bool any_on_device = tensor->extra
+        || (src0 != nullptr && src0->extra)
+        || (src1 != nullptr && src1->extra);
+
+    switch (tensor->op) {
+        case GGML_OP_GET_ROWS:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_get_rows;
+            break;
+        case GGML_OP_CPY:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_cpy;
+            break;
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_dup;
+            break;
+        case GGML_OP_ADD:
+            if (!any_on_device) {
+                return false;
+            }
+            GGML_ASSERT(ggml_is_contiguous(src0));
+            GGML_ASSERT(ggml_is_contiguous(src1));
+            func = ggml_cl_add;
+            break;
+        case GGML_OP_MUL:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_mul;
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_silu;
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_relu;
+                    break;
+                default:
+                    return false;
+            } break;
+        case GGML_OP_CLAMP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_clamp;
+            break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_norm;
+            break;
+        case GGML_OP_RMS_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_rms_norm;
+            break;
+        case GGML_OP_MUL_MAT:
+            if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
+                return false;
+            }
+            func = ggml_cl_mul_mat;
+            break;
+        case GGML_OP_SCALE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_scale;
+            break;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_nop;
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_diag_mask_inf;
+            break;
+        case GGML_OP_SOFT_MAX:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_soft_max;
+            break;
+        case GGML_OP_ROPE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_rope;
+            break;
+        default:
+            return false;
+    }
+
+    func(backend, tensor->src[0], tensor->src[1], tensor);
+    return true;
+}
diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py
new file mode 100644 (file)
index 0000000..b5d1d72
--- /dev/null
@@ -0,0 +1,26 @@
+#
+
+import sys
+import logging
+logger = logging.getLogger("opencl-embed-kernel")
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    if len(sys.argv) != 3:
+        logger.info("Usage: python embed_kernel.py <input_file> <output_file>")
+        sys.exit(1)
+
+    ifile = open(sys.argv[1], "r")
+    ofile = open(sys.argv[2], "w")
+
+    for i in ifile:
+        ofile.write('R"({})"\n'.format(i))
+
+    ifile.close()
+    ofile.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
new file mode 100644 (file)
index 0000000..d1cdf70
--- /dev/null
@@ -0,0 +1,2683 @@
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroups size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q4_1
+//------------------------------------------------------------------------------
+struct block_q4_1
+{
+    half d;
+    half m;
+    uint8_t qs[QK4_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_0
+//------------------------------------------------------------------------------
+struct block_q5_0
+{
+    half d;
+    uint32_t qh;
+    uint8_t qs[QK5_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_1
+//------------------------------------------------------------------------------
+struct block_q5_1
+{
+    half d;
+    half m;
+    uint32_t qh;
+    uint8_t qs[QK5_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q8_0
+//------------------------------------------------------------------------------
+struct block_q8_0
+{
+    half d;
+    int8_t qs[QK8_0];
+};
+
+//------------------------------------------------------------------------------
+// block_q2_K
+//------------------------------------------------------------------------------
+struct block_q2_K
+{
+    uint8_t scales[16];
+    uint8_t qs[64];
+    half d;
+    half dmin;
+};
+
+//------------------------------------------------------------------------------
+// block_q3_K
+//------------------------------------------------------------------------------
+struct block_q3_K
+{
+    uint8_t hmask[32];
+    uint8_t qs[64];
+    uint8_t scales[12];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+struct block_q4_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_K
+//------------------------------------------------------------------------------
+struct block_q5_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qh[32];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+struct block_q6_K
+{
+    uint8_t ql[128];
+    uint8_t qh[64];
+    int8_t scales[16];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// dequantize_q4_0_f32, dequantize_q4_0_f16
+//------------------------------------------------------------------------------
+void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    float d1 = il ? (xb->d / 16.h) : xb->d;
+    float d2 = d1 / 256.f;
+    float md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    half d1 = il ? (xb->d / 16.h) : xb->d;
+    half d2 = d1 / 256.h;
+    half md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+//------------------------------------------------------------------------------
+// add
+//------------------------------------------------------------------------------
+
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
+kernel void kernel_add(
+        global char * src0,
+        ulong  offset0,
+        global char * src1,
+        ulong  offset1,
+        global char * dst,
+        ulong  offsetd,
+        int   ne00,
+        int   ne01,
+        int   ne02,
+        int   ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int   ne10,
+        int   ne11,
+        int   ne12,
+        int   ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int   ne0,
+        int   ne1,
+        int   ne2,
+        int   ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        global float4 * src0,
+        ulong  offset0,
+        global float4 * src1,
+        ulong  offset1,
+        global float4 * dst,
+        ulong  offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] + src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// mul
+//------------------------------------------------------------------------------
+kernel void kernel_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * src1,
+        ulong offset1,
+        global float4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] * src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// scale
+//------------------------------------------------------------------------------
+kernel void kernel_scale(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        float scale
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+}
+
+//------------------------------------------------------------------------------
+// gelu
+//------------------------------------------------------------------------------
+#define GELU_COEF_A     0.044715f
+#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
+
+kernel void kernel_gelu(
+    global float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    global float4 * src0,
+    ulong offset0,
+    global float4 * dst,
+    ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+//------------------------------------------------------------------------------
+// silu
+//------------------------------------------------------------------------------
+kernel void kernel_silu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+//------------------------------------------------------------------------------
+// relu
+//------------------------------------------------------------------------------
+kernel void kernel_relu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// clamp
+//------------------------------------------------------------------------------
+kernel void kernel_clamp(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        float min,
+        float max
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
+        min :
+        (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// norm
+//------------------------------------------------------------------------------
+kernel void kernel_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global void*)((global char*)dst + offsetd);
+
+    global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
+
+    // MEAN
+    // parallel sum
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        sum[get_local_id(0)] += x[i00];
+    }
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float mean  = sum[0] / ne00;
+
+    // recenter and VARIANCE
+    barrier(CLK_LOCAL_MEM_FENCE);
+    global float * y = dst + get_group_id(0)*ne00;
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = x[i00] - mean;
+        sum[get_local_id(0)] += y[i00] * y[i00];
+    }
+
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float variance = sum[0] / ne00;
+
+    float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+//------------------------------------------------------------------------------
+// rms_norm
+//------------------------------------------------------------------------------
+// This kernel depends on subgroup size.
+kernel void kernel_rms_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum // Note, the size depends on number of subgroups
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
+    global float * x_scalar = (global float *) x;
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
+    all_sum = sub_group_reduce_add(all_sum);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = all_sum;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+    // broadcast
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+    if (get_local_id(0) == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float mean  = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
+    global float * y_scalar = (global float *) y;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = x[i00] * scale;
+    }
+    if (get_local_id(0) == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// diag_mask_inf kernels
+//------------------------------------------------------------------------------
+kernel void kernel_diag_mask_inf(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i02 = get_global_id(2);
+    int i01 = get_global_id(1);
+    int i00 = get_global_id(0);
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    int i = 2*get_global_id(0);
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int i4 = 4*i;
+    int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    int i01 = i4/(ne00);      i4 -= i01*ne00;
+    int i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        (&dst[i+1])[k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            (&dst[i])[k] = -INFINITY;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// softmax
+//------------------------------------------------------------------------------
+kernel void kernel_soft_max(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
+    global float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        pdst[i00] /= sum;
+    }
+}
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_soft_max_4(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
+    global float4 * pdst4 = (global float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
+
+    const float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        pdst4[i00] /= sum;
+    }
+}
+
+//------------------------------------------------------------------------------
+// kernel_rope
+//------------------------------------------------------------------------------
+float rope_yarn_ramp(float low, float high, int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+float2 rope_yarn(
+    float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    return (float2)(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+float2 rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
+) {
+    // start and end correction dims
+    return (float2)(
+        max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
+        min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
+    );
+}
+
+kernel void kernel_rope_norm_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global float * src       = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            float x0 = src[0];
+            float x1 = src[1];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_norm_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            float x0 = src[0];
+            float x1 = src[1];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_neox_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_neox_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// cpy
+//------------------------------------------------------------------------------
+
+kernel void kernel_cpy_f16_f16(
+        global half * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f16_f32(
+        global half * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        global float * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+//------------------------------------------------------------------------------
+// get_rows
+//------------------------------------------------------------------------------
+kernel void kernel_get_rows_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_q4_0(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int NL = 2;
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
+        float16 temp;
+        dequantize_q4_0_f32(
+            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
+        *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f32_f32
+//------------------------------------------------------------------------------
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F32_F32;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global float * x = (global float *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float * y = (global float *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global float4 * x4 = (global float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float  * y  = (global float  *) (src1 + offset_src1);
+            global float4 * y4 = (global float4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += (float) x4[i].s0 * y4[i].s0;
+                sumf += (float) x4[i].s1 * y4[i].s1;
+                sumf += (float) x4[i].s2 * y4[i].s2;
+                sumf += (float) x4[i].s3 * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (float) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f16
+//------------------------------------------------------------------------------
+#define N_F16_F16 4
+
+kernel void kernel_mul_mat_f16_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3)
+{
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F16_F16;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half * x = (global half *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global half * y = (global half *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global half4 * x4 = (global half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global half  * y  = (global half  *) (src1 + offset_src1);
+            global half4 * y4 = (global half4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += (half) x4[i].s0 * y4[i].s0;
+                sumf += (half) x4[i].s1 * y4[i].s1;
+                sumf += (half) x4[i].s2 * y4[i].s2;
+                sumf += (half) x4[i].s3 * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (half) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32_1row
+//------------------------------------------------------------------------------
+kernel void kernel_mul_mat_f16_f32_1row(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global half  * x = (global half  *) (src0 + offset_src0);
+    global float * y = (global float *) (src1 + offset_src1);
+
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        global half4  * x4 = (global half4  *) x;
+        global float4 * y4 = (global float4 *) y;
+        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+            sumf += (float) x4[i].s0 * y4[i].s0;
+            sumf += (float) x4[i].s1 * y4[i].s1;
+            sumf += (float) x4[i].s2 * y4[i].s2;
+            sumf += (float) x4[i].s3 * y4[i].s3;
+        }
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) {
+                all_sum += (float) x[i] * y[i];
+            }
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32
+//------------------------------------------------------------------------------
+#define N_F16_F32 4
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_f16_f32(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F16_F32;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half * x = (global half *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float * y = (global float *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += convert_float(x[i]) * y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global half4 * x4 = (global half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float  * y  = (global float  *) (src1 + offset_src1);
+            global float4 * y4 = (global float4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += convert_float(x4[i].s0) * y4[i].s0;
+                sumf += convert_float(x4[i].s1) * y4[i].s1;
+                sumf += convert_float(x4[i].s2) * y4[i].s2;
+                sumf += convert_float(x4[i].s3) * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (float) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32_l4
+//------------------------------------------------------------------------------
+// Assumes row size (ne00) is a multiple of 4
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_f16_f32_l4(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int nrows = ne11;
+    int r0 = get_group_id(0);
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half4 * x4 = (global half4 *) (src0 + offset_src0);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+        global float4 * y4 = (global float4 *) (src1 + offset_src1);
+
+        float sumf = 0;
+        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+            sumf += convert_float(x4[i].s0) * y4[i].s0;
+            sumf += convert_float(x4[i].s1) * y4[i].s1;
+            sumf += convert_float(x4[i].s2) * y4[i].s2;
+            sumf += convert_float(x4[i].s3) * y4[i].s3;
+        }
+
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32
+//------------------------------------------------------------------------------
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_4_0_dot_y(
+        global struct block_q4_0 * qb_curr,
+        float sumy,
+        private float * yl,
+        int il
+) {
+    float d = qb_curr->d;
+    float2 acc = 0.f;
+    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc.s0 + acc.s1);
+}
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32(
+        global void * src0,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global
+    // id of a SIMD group in the grid.
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
+    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];       // src1 vector cache
+    float sumf[N_DST]={0.f};
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
+        }
+
+        for (int row = 0; row < N_DST; row++) {
+            sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
+
+        // One thread in a SIMD group (i.e., subgroup) handles a half block,
+        // hence then entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore for
+        // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
+        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
+        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float tot[N_DST] = {
+        sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
+        sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
+    for (int row = 0; row < N_DST; ++row) {
+        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//
+// This variant unrolls the loops and uses vector types instead of pointers.
+// It improves performance on Adreno but not so much on Intel.
+//
+inline float block_q_4_0_dot_y_v(
+        global struct block_q4_0 * qb_curr,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float d = qb_curr->d;
+    float acc = 0.f;
+    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_v(
+        global void * src0,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global
+    // id of a SIMD group in the grid.
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
+    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;       // src1 vector cache
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
+
+        // One thread in a SIMD group (i.e., subgroup) handles a half block,
+        // hence then entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore for
+        // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
+        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
+        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_v(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//------------------------------------------------------------------------------
+// kernel_convert_block_q4_0
+// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
+// This kernel does not deshuffle the bits.
+//------------------------------------------------------------------------------
+kernel void kernel_convert_block_q4_0(
+    global struct block_q4_0 * src0,
+    global uchar * dst_q,
+    global half  * dst_d
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+
+    *d = b->d;
+
+    for (int i = 0; i < QK4_0/2; ++i) {
+        q[i] = b->qs[i];
+    }
+}
+
+kernel void kernel_restore_block_q4_0(
+    global uchar * src_q,
+    global half  * src_d,
+    global struct block_q4_0 * dst
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
+    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) src_d + get_global_id(0);
+
+    b->d = *d;
+    for (int i = 0; i < QK4_0/2; ++i) {
+        b->qs[i] = q[i];
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32_flat
+//
+// This variation uses flat arrays (struct of arrays, SOA) representation for
+// quant tensors.
+//------------------------------------------------------------------------------
+
+// This function requires the original shuffled weights.
+// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
+// packed together in a byte, so are (q[1], q[17]) and so on.
+inline float block_q_4_0_dot_y_flat(
+        global uchar * x,
+        global half  * dh,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float           d   = *dh;
+    global ushort * qs  = ((global ushort *)x + il/2);
+    float           acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 32
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//
+// This variant outputs 8 values.
+//
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 32
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float8 sumf = 0.f;
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
+        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
+        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
+        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float8 tot = (float8)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl
new file mode 100644 (file)
index 0000000..e202433
--- /dev/null
@@ -0,0 +1,106 @@
+//------------------------------------------------------------------------------
+// This file is contains additional kernels for data conversion.
+// These kernels are used when loading the model, so its performance is less
+// important.
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroups size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32_flat_noshuffle
+//
+// This variation uses flat arrays (struct of arrays, SOA) representation for
+// quant tensors. It also uses non shuffled bit order for weights.
+//
+// The shuffled version is kept in the original file because moving it here
+// seems to result in worse performance for adreno.
+//------------------------------------------------------------------------------
+
+kernel void kernel_convert_block_q4_0_noshuffle(
+    global struct block_q4_0 * src0,
+    global uchar * dst_q,
+    global half  * dst_d
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+
+    *d = b->d;
+    for (int i = 0; i < QK4_0/4; ++i) {
+        uchar x0 = b->qs[2*i + 0];
+        uchar x1 = b->qs[2*i + 1];
+
+        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
+        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
+
+#ifdef ADRENO_GPU
+        // Workaround for adreno - must have the following printf statement for
+        // the kernel to work properly. Otherwise it produces incorrect result.
+        // convert_uchar above also seems necessary.
+        // Compare against a large number so that it does not print anything.
+        // get_sub_group_local_id() also works.
+        if (get_global_id(0) == 65536*4096) {
+            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
+        }
+#endif
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl
new file mode 100644 (file)
index 0000000..5e19541
--- /dev/null
@@ -0,0 +1,265 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
+#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+// assume
+#define QK4_0 32
+#define N_SIMDGROUP 4
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
+    float shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
+    float8 shared_y; \
+    shared_y = sub_group_broadcast(y, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+__kernel void kernel_gemv_noshuffle(
+        __read_only  image1d_buffer_t src0_q,  // quantized A
+        global half2  * src0_d,  // A scales
+        __read_only  image1d_buffer_t src1,    // B
+        ulong offset1,            // offset to B (0)
+        global float * dst,     // C
+        ulong offsetd,            // offset to C (0)
+        uint K,               // K
+        int ne01,               // M
+        int ne02,               // 1
+        int ne10,               // K
+        int ne12,               // 1
+        int ne0,                // M
+        int ne1,                // N
+        int r2,                 // 1
+        int r3)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
+    __private uint4     regA;
+    __private half2     regS;
+    __private float8    regB;
+
+    __private float2 totalSum = (float2)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
+        // first 4 fibers in each wave load 8 B values to its private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load half weights for two blocks in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+    }
+
+    // reduction in local memory, assumes #wave=4
+    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
+    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
+    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
+    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 2 outputs per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        vstore2(totalSum, 0, &(dst[gid * 2]));
+    }
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl
new file mode 100644 (file)
index 0000000..5bdd4d0
--- /dev/null
@@ -0,0 +1,271 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
+#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+// assume
+#define QK4_0 32
+#define N_SIMDGROUP 4
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
+    float shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
+    float8 shared_y; \
+    shared_y = sub_group_broadcast(y, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+__kernel void kernel_gemv_noshuffle(
+        __read_only  image1d_buffer_t src0_q,  // quantized A
+        global half2  * src0_d,  // A scales
+        __read_only  image1d_buffer_t src1,    // B
+        ulong offset1,            // offset to B (0)
+        global float * dst,     // C
+        ulong offsetd,            // offset to C (0)
+        int ne00,               // K
+        int ne01,               // M
+        int ne02,               // 1
+        int ne10,               // K
+        int ne12,               // 1
+        int ne0,                // M
+        int ne1,                // N
+        int r2,                 // 1
+        int r3)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
+    uint K = ne00;
+    uint M = ne01;
+
+    uint LINE_STRIDE_A = M / 2;
+    uint BLOCK_STRIDE_A = N_SIMDGROUP * M;
+
+    __private uint4     regA;
+    __private half2     regS;
+    __private float8    regB;
+
+    __private float2 totalSum = (float2)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
+        // first 4 fibers in each wave load 8 B values to its private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load half weights for two blocks in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+    }
+
+    // reduction in local memory, assumes #wave=4
+    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
+    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
+    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
+    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 2 outputs per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        vstore2(totalSum, 0, &(dst[gid * 2]));
+    }
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl
new file mode 100644 (file)
index 0000000..e19e9a2
--- /dev/null
@@ -0,0 +1,1225 @@
+//------------------------------------------------------------------------------
+// This file is contains additional mulmat kernels
+// (and potentially other kernels).
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroups size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;             // super-block scale
+} block_q6_K;
+
+//------------------------------------------------------------------------------
+// These are the variant for matmatmul, based on the matvecmul kernel with
+// flattened block_q4_0.
+//------------------------------------------------------------------------------
+
+// Common dot prod.
+inline float mm_block_q_4_0_dot_y_flat(
+        global uchar * x,
+        global half  * dh,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float           d   = *dh;
+    global ushort * qs  = ((global ushort *)x + il/2);
+    float           acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+//
+// This variant performs 1d blocking with 8x output.
+// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix).
+//
+inline void mul_mat_q_n_f32_1d_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
+        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
+        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
+        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float8 tot = (float8)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix)
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 16
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+//
+// This variant performs 1d blocking with 16x output.
+// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix).
+//
+inline void mul_mat_q_n_f32_1d_16x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
+                             0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  0*nb*QK4_0/2, d + ib +  0*nb, sumy, yl, il);
+        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  1*nb*QK4_0/2, d + ib +  1*nb, sumy, yl, il);
+        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  2*nb*QK4_0/2, d + ib +  2*nb, sumy, yl, il);
+        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  3*nb*QK4_0/2, d + ib +  3*nb, sumy, yl, il);
+
+        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  4*nb*QK4_0/2, d + ib +  4*nb, sumy, yl, il);
+        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  5*nb*QK4_0/2, d + ib +  5*nb, sumy, yl, il);
+        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  6*nb*QK4_0/2, d + ib +  6*nb, sumy, yl, il);
+        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  7*nb*QK4_0/2, d + ib +  7*nb, sumy, yl, il);
+
+        sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  8*nb*QK4_0/2, d + ib +  8*nb, sumy, yl, il);
+        sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  9*nb*QK4_0/2, d + ib +  9*nb, sumy, yl, il);
+        sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);
+        sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);
+
+        sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);
+        sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);
+        sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);
+        sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float16 tot = (float16)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),
+
+        sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),
+        sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),
+        sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),
+        sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+
+        if (first_row + 8 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;
+        }
+        if (first_row + 9 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;
+        }
+        if (first_row + 10 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;
+        }
+        if (first_row + 11 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;
+        }
+
+        if (first_row + 12 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc;
+        }
+        if (first_row + 13 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;
+        }
+        if (first_row + 14 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;
+        }
+        if (first_row + 15 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//------------------------------------------------------------------------------
+// kernel_mul_mat_q4_0_f32_flat_v0
+//------------------------------------------------------------------------------
+inline float block_q_4_0_dot_y_flat_v2(
+    half   x,
+    half   d,
+    float  sumy,
+    float4 yl
+) {
+    uchar2 q = as_uchar2(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+
+    acc += (q.s0 & 0xF0) * yl.s2;
+    acc += (q.s1 & 0xF0) * yl.s3;
+
+    return d * (sumy * -8.f + acc);;
+}
+
+inline float block_q_4_0_dot_y_flat_v4(
+    float  x,
+    half   d,
+    float  sumy,
+    float8 yl
+) {
+    uchar4 q = as_uchar4(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+    acc += (q.s2 & 0x0F) * yl.s2;
+    acc += (q.s3 & 0x0F) * yl.s3;
+
+    acc += (q.s0 & 0xF0) * yl.s4;
+    acc += (q.s1 & 0xF0) * yl.s5;
+    acc += (q.s2 & 0xF0) * yl.s6;
+    acc += (q.s3 & 0xF0) * yl.s7;
+
+    return d * (sumy * -8.f + acc);;
+}
+
+inline float block_q_4_0_dot_y_flat_v8(
+    float2  x,
+    half    d,
+    float   sumy,
+    float16 yl
+) {
+    uchar8 q = as_uchar8(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+    acc += (q.s2 & 0x0F) * yl.s2;
+    acc += (q.s3 & 0x0F) * yl.s3;
+    acc += (q.s4 & 0x0F) * yl.s4;
+    acc += (q.s5 & 0x0F) * yl.s5;
+    acc += (q.s6 & 0x0F) * yl.s6;
+    acc += (q.s7 & 0x0F) * yl.s7;
+
+    acc += (q.s0 & 0xF0) * yl.s8;
+    acc += (q.s1 & 0xF0) * yl.s9;
+    acc += (q.s2 & 0xF0) * yl.sa;
+    acc += (q.s3 & 0xF0) * yl.sb;
+    acc += (q.s4 & 0xF0) * yl.sc;
+    acc += (q.s5 & 0xF0) * yl.sd;
+    acc += (q.s6 & 0xF0) * yl.se;
+    acc += (q.s7 & 0xF0) * yl.sf;
+
+    return d * (sumy * -8.f + acc);;
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define THREADS_PER_BLK 4   // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     16
+#elif defined (ADRENO_GPU)
+#define THREADS_PER_BLK 4
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     64
+#endif
+
+#if THREADS_PER_BLK == 2                // Each thread processes 1/2 block
+#   define ACT_TY                       float16
+#   define Q_BLK_LD_TY                  float2
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v8
+#elif THREADS_PER_BLK == 4              // Each thread processes 1/4 block
+#   define ACT_TY                       float8
+#   define Q_BLK_LD_TY                  float
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v4
+#elif THREADS_PER_BLK == 8              // Each thread processes 1/8 block
+#   define ACT_TY                       float4
+#   define Q_BLK_LD_TY                  half
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v2
+#endif
+
+#define BTYES_PER_THREAD_IN_BLK         (QK4_0/2/THREADS_PER_BLK)
+
+#if N_DST == 2
+#   define  SUM_TY                      float2
+#elif N_DST == 4
+#   define  SUM_TY                      float4
+#elif N_DST == 8
+#   define  SUM_TY                      float8
+#elif N_DST == 16
+#   define  SUM_TY                      float16
+#endif
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat_v0(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    int ix = get_sub_group_local_id()/THREADS_PER_BLK;
+    int il = get_sub_group_local_id()%THREADS_PER_BLK;
+
+    global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
+
+    // Registers for caching activation
+    ACT_TY yl = 0.f;
+
+    // Registers for caching quants
+    Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
+#endif
+#if N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
+#endif
+
+    // Partial sum
+    SUM_TY sumf = 0.f;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
+        float sumy = 0.f;
+
+        q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2);
+        q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2);
+        q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2);
+#endif
+#if N_DST == 8 || N_DST == 16
+        q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2));
+        q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2));
+        q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2));
+        q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2));
+#endif
+
+        // Load activation
+#if THREADS_PER_BLK == 2    // Each thread processes 1/2 block
+        yl.s01234567 = *(global float8 *)(yb);
+        yl.s89abcdef = *(global float8 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4;
+        sumy += yl.s5;
+        sumy += yl.s6;
+        sumy += yl.s7;
+        sumy += yl.s8; yl.s8 /= 16.f;
+        sumy += yl.s9; yl.s9 /= 16.f;
+        sumy += yl.sa; yl.sa /= 16.f;
+        sumy += yl.sb; yl.sb /= 16.f;
+        sumy += yl.sc; yl.sc /= 16.f;
+        sumy += yl.sd; yl.sd /= 16.f;
+        sumy += yl.se; yl.se /= 16.f;
+        sumy += yl.sf; yl.sf /= 16.f;
+#elif THREADS_PER_BLK == 4  // Each thread processes 1/4 block
+        yl.s0123 = *(global float4 *)(yb);
+        yl.s4567 = *(global float4 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4; yl.s4 /= 16.f;
+        sumy += yl.s5; yl.s5 /= 16.f;
+        sumy += yl.s6; yl.s6 /= 16.f;
+        sumy += yl.s7; yl.s7 /= 16.f;
+#elif THREADS_PER_BLK == 8  // Each thread processes 1/8 block
+        yl.s01 = *(global float2 *)(yb);
+        yl.s23 = *(global float2 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2; yl.s2 /= 16.f;
+        sumy += yl.s3; yl.s3 /= 16.f;
+#endif
+
+        sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl);
+        sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl);
+        sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl);
+#endif
+#if N_DST == 8 || N_DST == 16
+        sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl);
+        sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl);
+        sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl);
+        sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl);
+#endif
+
+        yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
+    }
+
+    SUM_TY tot = (SUM_TY)(
+          sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+#endif
+#if N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
+        , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+#endif
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+#endif
+#if N_DST == 8 || N_DST == 16
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+#endif
+    }
+}
+
+//------------------------------------------------------------------------------
+// Using image1d_buffer_t
+
+#if defined(cl_qcom_subgroup_shuffle)
+#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
+float qcom_sub_group_reduce_add(float sum) {
+    sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    return sum;
+}
+#define sub_group_reduce_add qcom_sub_group_reduce_add
+#else
+#define sub_group_reduce_add sub_group_reduce_add
+#endif
+
+#undef THREADS_PER_BLK
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define THREADS_PER_BLK 4   // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     16
+#elif defined (ADRENO_GPU)
+#define THREADS_PER_BLK 4
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     64
+#endif
+
+#if THREADS_PER_BLK == 2                // Each thread processes 1/2 block
+#   define ACT_TY                       float16
+#   define Q_BLK_LD_TY                  float2
+#   define EXTRACT_BLK_DATA(tmp, part)  *((float2*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v8
+#elif THREADS_PER_BLK == 4              // Each thread processes 1/4 block
+#   define ACT_TY                       float8
+#   define Q_BLK_LD_TY                  float
+#   define EXTRACT_BLK_DATA(tmp, part)  *((float*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v4
+#elif THREADS_PER_BLK == 8              // Each thread processes 1/8 block
+#   define ACT_TY                       float4
+#   define Q_BLK_LD_TY                  half
+#   define EXTRACT_BLK_DATA(tmp, part)  *((half*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v2
+#endif
+
+#define BTYES_PER_THREAD_IN_BLK         (QK4_0/2/THREADS_PER_BLK)
+
+#if N_DST == 2
+#   define  SUM_TY                      float2
+#elif N_DST == 4
+#   define  SUM_TY                      float4
+#elif N_DST == 8
+#   define  SUM_TY                      float8
+#elif N_DST == 16
+#   define  SUM_TY                      float16
+#endif
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat_img_v0(
+        read_only image1d_buffer_t src0_q,
+        read_only image1d_buffer_t src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    int ix = get_sub_group_local_id()/THREADS_PER_BLK;
+    int il = get_sub_group_local_id()%THREADS_PER_BLK;
+
+    global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
+
+    // Registers for caching activation
+    ACT_TY yl = 0.f;
+
+    // Registers for caching quants
+    Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
+#endif
+#if N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
+#endif
+
+    // Partial sum
+    SUM_TY sumf = 0.f;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
+        float sumy = 0.f;;
+
+        float4 tmp;
+        tmp = read_imagef(src0_q, offset0_q + ib + 0*nb);
+        q_blk_0 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 1*nb);
+        q_blk_1 = EXTRACT_BLK_DATA(tmp, il);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        tmp = read_imagef(src0_q, offset0_q + ib + 2*nb);
+        q_blk_2 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 3*nb);
+        q_blk_3 = EXTRACT_BLK_DATA(tmp, il);
+#endif
+#if N_DST == 8 || N_DST == 16
+        tmp = read_imagef(src0_q, offset0_q + ib + 4*nb);
+        q_blk_4 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 5*nb);
+        q_blk_5 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 6*nb);
+        q_blk_6 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 7*nb);
+        q_blk_7 = EXTRACT_BLK_DATA(tmp, il);
+#endif
+
+        // Load activation
+#if THREADS_PER_BLK == 2    // Each thread processes 1/2 block
+        yl.s01234567 = *(global float8 *)(yb);
+        yl.s89abcdef = *(global float8 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4;
+        sumy += yl.s5;
+        sumy += yl.s6;
+        sumy += yl.s7;
+        sumy += yl.s8; yl.s8 /= 16.f;
+        sumy += yl.s9; yl.s9 /= 16.f;
+        sumy += yl.sa; yl.sa /= 16.f;
+        sumy += yl.sb; yl.sb /= 16.f;
+        sumy += yl.sc; yl.sc /= 16.f;
+        sumy += yl.sd; yl.sd /= 16.f;
+        sumy += yl.se; yl.se /= 16.f;
+        sumy += yl.sf; yl.sf /= 16.f;
+#elif THREADS_PER_BLK == 4  // Each thread processes 1/4 block
+        yl.s0123 = *(global float4 *)(yb);
+        yl.s4567 = *(global float4 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4; yl.s4 /= 16.f;
+        sumy += yl.s5; yl.s5 /= 16.f;
+        sumy += yl.s6; yl.s6 /= 16.f;
+        sumy += yl.s7; yl.s7 /= 16.f;
+#elif THREADS_PER_BLK == 8  // Each thread processes 1/8 block
+        yl.s01 = *(global float2 *)(yb);
+        yl.s23 = *(global float2 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2; yl.s2 /= 16.f;
+        sumy += yl.s3; yl.s3 /= 16.f;
+#endif
+
+        sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl);
+        sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl);
+        sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl);
+#endif
+#if N_DST == 8 || N_DST == 16
+        sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl);
+        sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl);
+        sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl);
+        sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl);
+#endif
+
+        yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
+    }
+
+    SUM_TY tot = (SUM_TY)(
+          sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+#endif
+#if N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
+        , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+#endif
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+#endif
+#if N_DST == 8 || N_DST == 16
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+#endif
+    }
+}
+
+//------------------------------------------------------------------------------
+// kernel_mul_mv_q6_K_f32
+//------------------------------------------------------------------------------
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 1 // number of rows each SIMD group works on
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 1
+#define N_SIMDGROUP 2
+#define N_SIMDWIDTH 64
+#endif
+
+#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q6_K_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    uchar kmask1 = 0x03;
+    uchar kmask2 = 0x0C;
+    uchar kmask3 = 0x30;
+    uchar kmask4 = 0xC0;
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int row = N_SIMDGROUP * r0 + get_sub_group_id();
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
+    global float      * yy = (global float     *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf = 0;
+
+    // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a
+    // block. Values in a subblock shares a scale that is quantized with 8 bits;
+    // the entire block shares a single floating point scale.
+    // For work distribution, each thread processes a subblock (16 weights), hence
+    // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
+    // (super) blocks -- this is the block stride.
+    // The 16 threads that process a (super) block are split into 2 portions, each has
+    // 8 threads; each portion works on 8 subblocks.
+    // For subgroup of 16 threads, the entire subgroup works on a single (super) block
+    // before moving to the next (super) block. Thread0 - thread7 work on the
+    // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks.
+    // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
+    // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
+    // works on a total of 16 weight values.
+    int tid  = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
+    int ix   = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
+    int ip   = tid/8;   // first or second half of (super) block (0 or 1)
+    int il   = tid%8;   // each half has 8 parts, one per scale
+    int n    = 4;       // 4 scales at a time (and 4 sums)
+    int l0   = n*il;    // offset into half-block, 0..28
+    int is   = 8*ip + l0/16; // 0, 1, 8, 9
+
+    int y_offset = 128*ip + l0;
+    int q_offset_l = 64*ip + l0;
+    int q_offset_h = 32*ip + l0;
+
+    for (int i = ix; i < nb; i += BLOCK_STRIDE) {
+
+        global uint8_t * q1 = x[i].ql + q_offset_l;
+        global uint8_t * q2 = q1 + QK_K/8;
+        global uint8_t * qh = x[i].qh + q_offset_h;
+        global int8_t  * sc = x[i].scales + is;
+
+        global float * y = yy + i * QK_K + y_offset;
+
+        float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+        sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
+
+        sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
+    }
+
+    float tot = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
new file mode 100644 (file)
index 0000000..57768c8
--- /dev/null
@@ -0,0 +1,130 @@
+// src0_q, src0_d, src1 are transposed as a preprocessing step
+// 4-bit weights are transposed in groups of 4 (unsigned short int)
+// consider weights originally "next to each other", now "on top of each other"
+// each fiber computes a 8x4 tile of output elements
+// using unshuffled weights
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+kernel void kernel_mul_mat_Ab_Bi_8x4(
+        global const ushort * src0_q,       // quantized A
+        global const half  * src0_d,        // A scales
+        __read_only image1d_buffer_t src1,  // B (1d image)
+        global float * dst,                 // C
+        int m,                              // M
+        int n,                              // N with padding
+        int k,                              // K
+        int n_no_padding                    // N without padding
+) {
+
+    int m_4 = m >> 2;
+    int n_4 = n >> 2;
+
+    int gy = get_global_id(0);
+    int gx = get_global_id(1);
+    int gx_2 = gx << 2;
+
+    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements
+    half8 B; // registers for activations
+    half4 dequantized_weights; // registers for dequantized weights
+    __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights
+    __global const half* scale_ptr = src0_d + gx_2; // pointer for scales
+
+    for(int i=0; i<k; i+=4){ //loop through K dimension
+
+        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
+
+        // keep (i/4) and (i/32) in parenthesis, rounds down
+        // load 4 consecutive groups of 4 weights
+        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); // (i/4) because weights grouped in 4s
+
+        // load 4 consecutive scales
+        half4 scale = vload4(0, scale_ptr + (i/32)*(m));// (i/32) because 1 scale per 32 elements
+
+        // j=0
+        dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
+        dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
+        dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=1
+        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=2
+        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=3
+        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+    }
+
+    int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements
+
+    // conditional check if store is to a valid location. Required when N is not a multiple of 8
+    // if statements allow registers to be reused for each store
+    // provides a performance boost due to reduced register footprint, which increases number of concurrent waves
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl
new file mode 100644 (file)
index 0000000..d59a0c0
--- /dev/null
@@ -0,0 +1,32 @@
+// 16-bit transpose, loading/storing an 8x8 tile of elements
+
+kernel void kernel_transpose_16(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_3 = i<<3;
+    const int j_3 = j<<3;
+
+    ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i));
+    ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i));
+    ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i));
+    ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i));
+    ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i));
+    ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i));
+    ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i));
+    ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i));
+
+    write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0)));
+    write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1)));
+    write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2)));
+    write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3)));
+    write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4)));
+    write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5)));
+    write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6)));
+    write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7)));
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl
new file mode 100644 (file)
index 0000000..914ec01
--- /dev/null
@@ -0,0 +1,25 @@
+// 32-bit transpose, loading/storing a 4x4 tile of elements
+
+kernel void kernel_transpose_32(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+
+    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
+    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
+    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
+    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
+
+    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
+    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl
new file mode 100644 (file)
index 0000000..d3bd1fa
--- /dev/null
@@ -0,0 +1,35 @@
+// 32-bit transpose, loading/storing a 4x4 tile of elements
+// Only used for activations
+// converts to FP16
+// also adds zero padding for non multiple of 8 prompt lengths
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+    half4 temp0 = {0,0,0,0}; // initialize outputs to 0
+    half4 temp1 = {0,0,0,0};
+    half4 temp2 = {0,0,0,0};
+    half4 temp3 = {0,0,0,0};
+
+    if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
+        temp0 = read_imageh(input, (j_2+0)*cols+i);
+    }
+    if((j_2+1)*cols+i*4+3 < rows*cols*16){
+        temp1 = read_imageh(input, (j_2+1)*cols+i);
+    }
+    if((j_2+2)*cols+i*4+3 < rows*cols*16){
+        temp2 = read_imageh(input, (j_2+2)*cols+i);
+    }
+    if((j_2+3)*cols+i*4+3 < rows*cols*16){
+        temp3 = read_imageh(input, (j_2+3)*cols+i);
+    }
+
+    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
+    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+}