]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (fused soft max, gpu cpy ops, etc.) (#640)
authorGeorgi Gerganov <redacted>
Thu, 7 Dec 2023 20:26:34 +0000 (22:26 +0200)
committerGitHub <redacted>
Thu, 7 Dec 2023 20:26:34 +0000 (22:26 +0200)
* sync : llama.cpp (fused soft max, gpu cpy ops, etc.)

ggml-ci

* cuda : restore accidentally deleted changes

ggml-ci

* cuda : fix rope + disable device-side dequantize

ggml-ci

* test-backend-ops : enable stablelm rope test

* cuda : remove rope assert

* sync.sh : add test-backend-ops

* ggml : fix ggml_concat + ggml_get_n_tasks logic

* sync : whisper.cpp

ggml-ci

* metal : fix assert

* ci : fix Metal path to shaders

ggml-ci

* whisper : fix bug if metal init fails

---------

Co-authored-by: slaren <redacted>
16 files changed:
.github/workflows/ci.yml
ci/run.sh
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
include/ggml/ggml.h
scripts/sync-llama.sh
src/ggml-alloc.c
src/ggml-cuda.cu
src/ggml-metal.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-opencl.cpp
src/ggml-quants.c
src/ggml.c
tests/test-backend-ops.cpp

index ba719f6540d146a5e567536dbd90efe9ddacd22e..4da1cd48a79c8344b89ed3b687a55784549bb206 100644 (file)
@@ -45,6 +45,7 @@ jobs:
         llvm-profdata merge -sparse tests/*.profraw -o ggml.profdata
         llvm-cov      report ./bin/test-grad0 -instr-profile=ggml.profdata
         llvm-cov      report ./bin/test-opt   -instr-profile=ggml.profdata
+
   test-macos-metal:
     runs-on: macos-13
     env:
index 61fbef19e8b788578045a3436eadef85e23a5c58..299da671be30f0705412c4faad382e5a900e0318 100644 (file)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -375,6 +375,11 @@ ret=0
 
 test $ret -eq 0 && gg_run ctest_debug
 test $ret -eq 0 && gg_run ctest_release
+
+if [ ! -z ${GG_BUILD_METAL} ]; then
+    export GGML_METAL_PATH_RESOURCES="${SRC}/build-ci-release/bin"
+fi
+
 test $ret -eq 0 && gg_run gpt_2
 test $ret -eq 0 && gg_run mnist
 test $ret -eq 0 && gg_run whisper
index 98af5839ca551fe25f15b65d8c37b6c19d0672f1..9699802e023d46279f3c4f3c69b87053cd72a395 100644 (file)
@@ -165,8 +165,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
-        else if (arg == "-ls"   || arg == "--log-score")       { params.log_score = true; }
-        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu = false; }
+        else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
+        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
index 376712157a93dd194c04364ce02a07592b998285..e709e29fd5006ad9e1149fef1287c91e7619a526 100644 (file)
@@ -1077,6 +1077,10 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
         backend_gpu = ggml_backend_metal_init();
         if (!backend_gpu) {
             WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
+            WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
+            ggml_backend_free(backend_gpu);
+            backend_gpu = NULL;
         }
     }
 #endif
@@ -1611,24 +1615,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 // read into a temporary buffer first, then copy to device memory
                 read_buf.resize(ggml_nbytes(tensor));
 
-                // we repeat the 2 bias tensors along dim 0:
-                // [1, 512] -> [3000, 512] (conv1.bias)
-                // [1, 512] -> [1500, 512] (conv2.bias)
-                if (false) {
-                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
-
-                    float * data_f32 = (float *) read_buf.data();
-                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
-                        const int64_t yy = tensor->ne[1] - y - 1;
-                        const float val = data_f32[yy];
-
-                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
-                            data_f32[yy*tensor->ne[0] + x] = val;
-                        }
-                    }
-                } else {
-                    loader->read(loader->context, read_buf.data(), read_buf.size());
-                }
+                loader->read(loader->context, read_buf.data(), read_buf.size());
 
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
@@ -3513,7 +3500,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
 
-    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+    whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
 
     if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -3526,19 +3513,10 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
-        return false;
-    }
-
-    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
-
-    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
-
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
-        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
-        return 1;
+        return -1;
     }
 
-    return 0;
+    return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
 }
 
 int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
@@ -3590,6 +3568,17 @@ const char * whisper_lang_str(int id) {
     return nullptr;
 }
 
+const char * whisper_lang_str_full(int id) {
+   for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.second.second.c_str();
+        }
+    }
+
+    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
+    return nullptr;
+}
+
 int whisper_lang_auto_detect_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -5174,10 +5163,10 @@ int whisper_full_with_state(
             const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
 
             params.progress_callback(
-                ctx, ctx->state, progress_cur, params.progress_callback_user_data);
+                ctx, state, progress_cur, params.progress_callback_user_data);
         }
 
-        // of only 1 second left, then stop
+        // if only 1 second left, then stop
         if (seek + 100 >= seek_end) {
             break;
         }
@@ -6061,6 +6050,43 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
     // 1GB array
     const size_t size = arr*1e6;
 
+    double sum  = 0.0;
+
+    // heat-up
+    {
+        char * src = (char *) malloc(size);
+        char * dst = (char *) malloc(size);
+
+        for (size_t i = 0; i < size; i++) src[i] = i;
+
+        memcpy(dst, src, size); // heat-up
+
+        double tsum = 0.0;
+
+        for (size_t i = 0; i < n; i++) {
+            const int64_t t0 = ggml_time_us();
+
+            memcpy(dst, src, size);
+
+            const int64_t t1 = ggml_time_us();
+
+            tsum += (t1 - t0)*1e-6;
+
+            src[rand() % size] = rand() % 256;
+        }
+
+        snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
+        s += strbuf;
+
+        // needed to prevent the compiler from optimizing the memcpy away
+        {
+            for (size_t i = 0; i < size; i++) sum += dst[i];
+        }
+
+        free(src);
+        free(dst);
+    }
+
     // single-thread
     {
         char * src = (char *) malloc(size);
@@ -6071,7 +6097,6 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
         memcpy(dst, src, size); // heat-up
 
         double tsum = 0.0;
-        double sum  = 0.0;
 
         for (size_t i = 0; i < n; i++) {
             const int64_t t0 = ggml_time_us();
@@ -6085,21 +6110,73 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
             src[rand() % size] = rand() % 256;
         }
 
-        snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
+        snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
         s += strbuf;
 
         // needed to prevent the compiler from optimizing the memcpy away
         {
             for (size_t i = 0; i < size; i++) sum += dst[i];
+        }
 
-            snprintf(strbuf, sizeof(strbuf), "sum:    %f\n", sum);
-            s += strbuf;
+        free(src);
+        free(dst);
+    }
+
+    // multi-thread
+
+    for (uint32_t k = 1; k <= n_threads; k++) {
+        char * src = (char *) malloc(size);
+        char * dst = (char *) malloc(size);
+
+        for (size_t i = 0; i < size; i++) src[i] = i;
+
+        memcpy(dst, src, size); // heat-up
+
+        double tsum = 0.0;
+
+        auto helper = [&](int th) {
+            const int64_t i0 = (th + 0)*size/k;
+            const int64_t i1 = (th + 1)*size/k;
+
+            for (size_t i = 0; i < n; i++) {
+                memcpy(dst + i0, src + i0, i1 - i0);
+
+                src[i0 + rand() % (i1 - i0)] = rand() % 256;
+            };
+        };
+
+        const int64_t t0 = ggml_time_us();
+
+        std::vector<std::thread> threads(k - 1);
+        for (uint32_t th = 0; th < k - 1; ++th) {
+            threads[th] = std::thread(helper, th);
+        }
+
+        helper(k - 1);
+
+        for (uint32_t th = 0; th < k - 1; ++th) {
+            threads[th].join();
+        }
+
+        const int64_t t1 = ggml_time_us();
+
+        tsum += (t1 - t0)*1e-6;
+
+        snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
+        s += strbuf;
+
+        // needed to prevent the compiler from optimizing the memcpy away
+        {
+            for (size_t i = 0; i < size; i++) sum += dst[i];
         }
 
         free(src);
         free(dst);
     }
 
+    snprintf(strbuf, sizeof(strbuf), "sum:    %f\n", sum);
+    s += strbuf;
+
     return s.c_str();
 }
 
index 84540989d2ed339fb1ef4e834fa222af44e49e7b..3143ceaaf180b65dcd4d892c43b130f1f309c8a3 100644 (file)
@@ -50,7 +50,9 @@ extern "C" {
     //
     //     ...
     //
-    //     struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
+    //     whisper_context_params cparams = whisper_context_default_params();
+    //
+    //     struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);
     //
     //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
     //         fprintf(stderr, "failed to process audio\n");
@@ -313,6 +315,9 @@ extern "C" {
     // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
     WHISPER_API const char * whisper_lang_str(int id);
 
+    // Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str_full(int id);
+
     // Use mel data at offset_ms to try and auto-detect the spoken language
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
     // Returns the top language id or negative on failure
index 353d52e1987ce72f1aa7ce731f957a84366c057d..a8f10cbd5c1d8ead7963644753cd3d8bdd776f05 100644 (file)
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
-            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            fflush(stderr); \
             fflush(stdout); \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             ggml_print_backtrace(); \
-            exit(1); \
+            abort(); \
         } \
     } while (0)
 
@@ -1312,6 +1311,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // fused soft_max(a*scale + mask)
+    // mask is optional
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2090,6 +2097,7 @@ extern "C" {
     GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
     GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
     GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
index 6a04b3f270a9858d3d1d463fadec3512290b6dbd..67143c912120c35da8ea8d6484a32ea9f9559a14 100755 (executable)
@@ -24,3 +24,4 @@ cp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp
 cp -rpv ../llama.cpp/tests/test-grad0.cpp         tests/test-grad0.cpp
 cp -rpv ../llama.cpp/tests/test-quantize-fns.cpp  tests/test-quantize-fns.cpp
 cp -rpv ../llama.cpp/tests/test-quantize-perf.cpp tests/test-quantize-perf.cpp
+cp -rpv ../llama.cpp/tests/test-backend-ops.cpp   tests/test-backend-ops.cpp
index 80f2a4975f457f75badeb37250c83233eb9e02c9..d3049efb497a0a09a2c0b2b94f45629f8531e0a7 100644 (file)
@@ -135,10 +135,9 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
         ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
     }
 
-
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
index 1d2f8cacd9c33c6c0b4727552fdd6a1215189364..85f7a293783bed95758dd729fb49fb6c97d27017 100644 (file)
@@ -1,6 +1,7 @@
 #include <algorithm>
 #include <cstddef>
 #include <cstdint>
+#include <cinttypes>
 #include <float.h>
 #include <limits>
 #include <stdint.h>
@@ -237,7 +238,7 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
 static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
 
     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -247,7 +248,7 @@ static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const
 }
 
 static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
 
     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -257,11 +258,11 @@ static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, con
 }
 
 static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
 static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
 template<typename T>
@@ -442,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -469,7 +471,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
 #define MAX_STREAMS 8
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -500,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
 }
@@ -623,23 +650,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -2293,6 +2303,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
@@ -2304,7 +2315,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-
+    (void)x_qh; (void)x_sc;
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
     GGML_CUDA_ASSUME(k >= 0);
@@ -2313,7 +2324,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
 
-    const block_q4_0 * bx0 = (block_q4_0 *) vx;
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
 
     float * x_dmf = (float *) x_dm;
 
@@ -2351,9 +2362,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
 
     int u[2*VDR_Q4_0_Q8_1_MMQ];
 
@@ -2387,6 +2399,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
@@ -2398,6 +2411,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2407,7 +2421,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
 
-    const block_q4_1 * bx0 = (block_q4_1 *) vx;
+    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2442,6 +2456,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
@@ -2479,6 +2494,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
@@ -2490,6 +2506,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2499,7 +2516,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
 
-    const block_q5_0 * bx0 = (block_q5_0 *) vx;
+    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2554,6 +2571,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
@@ -2593,6 +2611,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
@@ -2604,6 +2623,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset < nwarps);
@@ -2613,7 +2633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
 
-    const block_q5_1 * bx0 = (block_q5_1 *) vx;
+    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2665,6 +2685,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -2699,6 +2720,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
@@ -2710,6 +2732,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2720,7 +2743,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kqsx = k % QI8_0;
     float * x_dmf = (float *) x_dm;
 
-    const block_q8_0 * bx0 = (block_q8_0 *) vx;
+    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2755,6 +2778,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -2788,6 +2812,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
@@ -2801,6 +2826,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2810,7 +2836,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
 
-    const block_q2_K * bx0 = (block_q2_K *) vx;
+    const block_q2_K * bx0 = (const block_q2_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2858,6 +2884,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const int kbx = k / QI2_K;
     const int ky  = (k % QI2_K) * QR2_K;
@@ -2931,7 +2958,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
 
-    const block_q3_K * bx0 = (block_q3_K *) vx;
+    const block_q3_K * bx0 = (const block_q3_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3012,7 +3039,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
 
-    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
     int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 
@@ -3127,6 +3154,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
@@ -3140,6 +3168,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3149,7 +3178,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
 
-    const block_q4_K * bx0 = (block_q4_K *) vx;
+    const block_q4_K * bx0 = (const block_q4_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3194,7 +3223,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
 
-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;
 
         const int ksc = k % (WARP_SIZE/8);
 
@@ -3209,6 +3238,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
@@ -3308,6 +3338,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
@@ -3321,6 +3352,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3330,7 +3362,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
 
-    const block_q5_K * bx0 = (block_q5_K *) vx;
+    const block_q5_K * bx0 = (const block_q5_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3386,7 +3418,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
 
-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;
 
         const int ksc = k % (WARP_SIZE/8);
 
@@ -3401,6 +3433,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
 
@@ -3437,6 +3470,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
@@ -3450,6 +3484,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3459,7 +3494,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256
 
-    const block_q6_K * bx0 = (block_q6_K *) vx;
+    const block_q6_K * bx0 = (const block_q6_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3521,6 +3556,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -3563,7 +3599,7 @@ static __device__ __forceinline__ void mul_mat_q(
     __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
     __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
 
-    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f};
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
 
     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
 
@@ -4568,6 +4604,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -4628,8 +4774,8 @@ static __global__ void rope(
 
 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -4638,23 +4784,25 @@ static __global__ void rope_neox(
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col/2;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
-    const float cur_rot = -float(col)/ncols;
+    float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, cur_rot);
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
-    const float x1 = x[i + ncols/2];
+    const float x1 = x[i + n_dims/2];
 
-    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
 static __global__ void rope_glm_f32(
@@ -4793,45 +4941,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
@@ -5047,13 +5224,13 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __host__ __device__ void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 template<typename dst_t>
-static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -5063,7 +5240,7 @@ static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t
 }
 
 template<typename dst_t>
-static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -5073,13 +5250,13 @@ static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t
 }
 
 template<typename dst_t>
-static __host__ __device__ void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -5089,7 +5266,7 @@ static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t
 }
 
 template<typename dst_t>
-static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -5098,7 +5275,7 @@ static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t
 #endif
 }
 
-static to_fp16_cuda_t __host__ __device__ ggml_get_to_fp16_cuda(ggml_type type) {
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -5835,6 +6012,39 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -5877,20 +6087,26 @@ static void rope_cuda(
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.0f / n_dims;
+
     if (pos == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     }
 }
@@ -5943,10 +6159,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6499,6 +6717,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q8_0:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
             return 1;
         case GGML_TYPE_Q2_K:
             return max_compute_capability >= CC_RDNA2 ? 128 : 32;
@@ -6521,6 +6740,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q8_0:
             return 64;
         case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
             return 1;
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
@@ -6540,6 +6760,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
@@ -6599,7 +6821,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     size_t ash;
     dfloat * src1_dfloat = nullptr; // dfloat == half
 
-    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
         src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
@@ -6821,15 +7044,14 @@ inline void ggml_cuda_op_rope(
         GGML_ASSERT(false);
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
-        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda(
-                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else {
@@ -6989,14 +7211,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
@@ -7202,10 +7428,9 @@ static void ggml_cuda_op_mul_mat(
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
-
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
-    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
     GGML_ASSERT(!(split && ne02 > 1));
@@ -7267,7 +7492,7 @@ static void ggml_cuda_op_mul_mat(
         if (src0_on_device && src0_is_contiguous) {
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
-            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
@@ -7330,7 +7555,7 @@ static void ggml_cuda_op_mul_mat(
                 const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
                 float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -7799,10 +8024,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
 #endif // GGML_CUDA_FORCE_DMMV
 
             if (use_mul_mat_vec_q) {
+                // NOTE: this kernel does not support ggml_nrows(src1) > 1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
@@ -8084,14 +8310,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -8102,6 +8331,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 }
 
 static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // TODO: why do we pass dst as src1 here?
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }
@@ -8434,7 +8664,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
     if (tensor->op == GGML_OP_MUL_MAT) {
         if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = " PRId64 ", src1->ne[3] = " PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
 #endif
             return false;
         }
index 14ab5029cf8f7ad7354452b20103cfb3fe91a232..bf52d9cd34da48246be9f343b89f4557b557296b 100644 (file)
@@ -98,9 +98,13 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
+GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
 GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
 
-GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 
 #ifdef __cplusplus
 }
index 3b43d5f6b68a511187e5a73b8fbfd085894f6c5b..f9bd69dc84bbe78128c5b9a01917decc40b627f0 100644 (file)
@@ -134,6 +134,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
@@ -230,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
             NSString * sourcePath;
             NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+            GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
             if (ggmlMetalPathResources) {
                 sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
             } else {
@@ -378,6 +386,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
@@ -473,6 +486,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
@@ -606,11 +624,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
+            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -630,11 +648,11 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
+                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
+                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
                     GGML_METAL_LOG_INFO("\n");
                 }
@@ -645,8 +663,8 @@ bool ggml_metal_add_buffer(
 
 #if TARGET_OS_OSX
         GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
-                ctx->device.currentAllocatedSize / 1e6,
-                ctx->device.recommendedMaxWorkingSetSize / 1e6);
+                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
             GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
@@ -654,7 +672,7 @@ bool ggml_metal_add_buffer(
             GGML_METAL_LOG_INFO("\n");
         }
 #else
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
+        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
 #endif
     }
 
@@ -782,7 +800,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
                 default:
                     return false;
             }
-            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -812,7 +829,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_GET_ROWS:
             {
                 return op->ne[0] % 4 == 0;
-            } break;
+            }
         default:
             return false;
     }
@@ -1162,20 +1179,27 @@ void ggml_metal_graph_compute(
                             int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
+                                while (nth < ne00/4 && nth < 256) {
+                                    nth *= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
-                                do {
+                                while (nth < ne00 && nth < 1024) {
                                     nth *= 2;
-                                } while (nth <= ne00 && nth <= 1024);
-                                nth /= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+                            const float scale = ((float *) dst->op_params)[0];
+
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
+                            [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
+                            [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
+                            [encoder setBytes:&ne02  length:sizeof(ne02)  atIndex:5];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1245,7 +1269,7 @@ void ggml_metal_graph_compute(
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 ne00 % 32 == 0 && ne00 >= 64 &&
-                                ne11 > ne11_mm_min) {
+                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
@@ -1553,15 +1577,19 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = MIN(512, ne00);
+                            int nth = 32; // SIMD width
+
+                            while (nth < ne00/4 && nth < 1024) {
+                                nth *= 2;
+                            }
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1635,7 +1663,8 @@ void ggml_metal_graph_compute(
                             const int n_past     = ((int32_t *) dst->op_params)[0];
                             const int n_dims     = ((int32_t *) dst->op_params)[1];
                             const int mode       = ((int32_t *) dst->op_params)[2];
-                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+                            // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
 
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1760,14 +1789,23 @@ void ggml_metal_graph_compute(
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                         {
-                            const int nth = MIN(1024, ne00);
+                            GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                            int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 
                             switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
+                                        GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
                                         switch (dstt) {
-                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
-                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16];  break;
+                                            case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];  break;
+                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+                                            //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+                                            //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
@@ -2057,6 +2095,14 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     ggml_metal_set_n_cb(ctx, n_cb);
 }
 
+bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
 ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
 
 ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
index 2def5f8202bae069560f0cb6b2fa975d312455fe..2f8ea22d66226040f9b9b1de4b70466994f6ff94 100644 (file)
@@ -3,6 +3,7 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
 #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
 
 #define QK4_0 32
@@ -40,6 +41,8 @@ typedef struct {
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
 
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
 enum ggml_sort_order {
     GGML_SORT_ASC,
     GGML_SORT_DESC,
@@ -328,10 +331,12 @@ kernel void kernel_gelu(
 
 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -342,73 +347,77 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
+    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+    float lmax = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    max = buf[0];
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
 
-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }
 
 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -419,64 +428,68 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+    float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
 
-    max = buf[0];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sum = buf[0];
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }
 
@@ -583,14 +596,13 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
+        threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x        = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float  * x_scalar = (device const float  *) x;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 
     float4 sumf = 0;
     float all_sum = 0;
@@ -601,40 +613,30 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (tiisg == 0) {
-        sum[sgitg] = all_sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           sum[tpitg] += sum[tpitg + i];
-       }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
         }
-        sum[0] /= ne00;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    const float mean  = sum[0];
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }
 
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -724,7 +726,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4        // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
 //      giard against the number of rows not being divisible by
@@ -1724,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_cpy_f32_q8_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
 kernel void kernel_concat(
     device const char * src0,
     device const char * src1,
index 202bcb4853893c7332906418059c9111ddedda8d..496f9cdca542d348e8edc4c08e0c25e6fefe6fe2 100644 (file)
@@ -1,20 +1,18 @@
+#include "ggml.h"
 #include "ggml-opencl.h"
 
 #include <array>
 #include <atomic>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
 #include <sstream>
 #include <vector>
-#include <limits>
 
 #define CL_TARGET_OPENCL_VERSION 110
 #include <clblast.h>
 
-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-
-#include "ggml.h"
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
index cf2860b8cbd5924c3da459a5feb78541dd120323..7285d5f7fbcc00ce41b5d481b702ecc52c5671f6 100644 (file)
@@ -19,7 +19,7 @@
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
-#ifdef __POWER9_VECTOR__
+#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
 #include <altivec.h>
 #undef bool
 #define bool _Bool
index 26c86c42f58ffaf8736689a3940b9c37eeb79f07..ca56f063c3a87440353e3efce428c18e003517fa 100644 (file)
@@ -4884,7 +4884,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
 static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -4893,9 +4903,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = mask;
 
     return result;
 }
@@ -4903,13 +4917,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, false);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, true);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, false);
 }
 
 // ggml_soft_max_back
@@ -8437,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -8446,7 +8469,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb10 == sizeof(float));
 
     for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
             if (i2 < ne02) { // src0
                 for (int i1 = 0; i1 < ne1; i1++) {
                     for (int i0 = 0; i0 < ne0; i0++) {
@@ -8808,6 +8831,7 @@ static void ggml_compute_forward_gelu_f32(
     // row range for this thread
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_gelu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
@@ -9483,7 +9507,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
-        src0->type == GGML_TYPE_F32 &&
+      //src0->type == GGML_TYPE_F32 &&
         src1->type == GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
@@ -9743,10 +9767,12 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
     GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
     GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9757,18 +9783,25 @@ static void ggml_compute_forward_out_prod_f32(
     // GGML_ASSERT(nb1 <= nb2);
     // GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
     // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
-    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+    // TODO: #if defined(GGML_USE_CLBLAST)
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    bool use_blas = ggml_is_matrix(src0) &&
+        ggml_is_matrix(src1) &&
+        ggml_is_contiguous(src0) &&
+        (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
+#endif
 
     if (params->type == GGML_TASK_INIT) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
+        if (use_blas) {
+            return;
+        }
+#endif
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
         return;
     }
@@ -9777,6 +9810,50 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (use_blas) {
+        if (params->ith != 0) { // All threads other than the first do no work.
+            return;
+        }
+        // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+        // src0: (k,n)
+        // src1: (k,m)
+        // dst:  (m,n)
+        //
+        // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+        // Also expressed as (major,minor)
+        // a: (m,k): so src1 transposed
+        // b: (k,n): so src0
+        // c: (m,n)
+        //
+        // However, if ggml_is_transposed(src1) is true, then
+        // src1->data already contains a transposed version, so sgemm mustn't
+        // transpose it further.
+
+        int n = src0->ne[0];
+        int k = src0->ne[1];
+        int m = src1->ne[0];
+
+        int transposeA, lda;
+
+        if (!ggml_is_transposed(src1)) {
+            transposeA = CblasTrans;
+            lda = m;
+        } else {
+            transposeA = CblasNoTrans;
+            lda = k;
+        }
+
+        float * a = (float *) ((char *) src1->data);
+        float * b = (float *) ((char *) src0->data);
+        float * c = (float *) ((char *) dst->data);
+
+        cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+        return;
+    }
+#endif
+
     // dst[:,:,:,:] = 0
     // for i2,i3:
     //   for i1:
@@ -10630,20 +10707,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -10654,29 +10736,40 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
+        }
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
@@ -10701,11 +10794,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -14007,7 +14101,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15749,7 +15843,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK:
@@ -15765,6 +15858,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
@@ -15858,7 +15955,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
-                printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
+                fprintf(stderr, "%s: op not implemented: ", __func__);
+                if (node->op < GGML_OP_COUNT) {
+                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
+                } else {
+                    fprintf(stderr, "%d\n", node->op);
+                }
                 GGML_ASSERT(false);
             } break;
     }
@@ -15999,18 +16101,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;
 
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
@@ -16018,16 +16118,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
@@ -16072,12 +16168,14 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -16103,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_IM2COL:
-                {
-                    n_tasks = n_threads;
-                } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16123,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16137,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16149,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t    D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16165,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
@@ -18632,24 +18718,29 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
 }
 
 const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].key.data;
 }
 
 enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].type;
 }
 
 enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.type;
 }
 
 const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.data;
 }
 
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     struct gguf_kv * kv = &ctx->kv[key_id];
     struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
@@ -18657,70 +18748,90 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
 }
 
 int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.n;
 }
 
 uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
     return ctx->kv[key_id].value.uint8;
 }
 
 int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
     return ctx->kv[key_id].value.int8;
 }
 
 uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
     return ctx->kv[key_id].value.uint16;
 }
 
 int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
     return ctx->kv[key_id].value.int16;
 }
 
 uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
     return ctx->kv[key_id].value.uint32;
 }
 
 int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
     return ctx->kv[key_id].value.int32;
 }
 
 float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
     return ctx->kv[key_id].value.float32;
 }
 
 uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
     return ctx->kv[key_id].value.uint64;
 }
 
 int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
     return ctx->kv[key_id].value.int64;
 }
 
 double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
     return ctx->kv[key_id].value.float64;
 }
 
 bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
     return ctx->kv[key_id].value.bool_;
 }
 
 const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
     return ctx->kv[key_id].value.str.data;
 }
 
+const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
+    return &ctx->kv[key_id].value;
+}
+
 int gguf_get_n_tensors(const struct gguf_context * ctx) {
     return ctx->header.n_tensors;
 }
index a5110c663e2a6ce8c05f0450e007eea3aae57c39..e0155ac1c89139367a6a09e77b2f8a1cc3501ab0 100644 (file)
@@ -1239,7 +1239,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
         test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
         test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) (TODO: enable after llama.cpp sync)
+        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
     }
 
     test_cases.emplace_back(new test_alibi());