]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kompute : improve backend to pass test_backend_ops (#10542)
authorSergio López <redacted>
Thu, 28 Nov 2024 11:51:38 +0000 (12:51 +0100)
committerGitHub <redacted>
Thu, 28 Nov 2024 11:51:38 +0000 (12:51 +0100)
* kompute: op_unary: reject unsupported parameters

Signed-off-by: Sergio Lopez <redacted>
* kompute: softmax: implement ALiBi support

Signed-off-by: Sergio Lopez <redacted>
* kompute: rope: implement neox and phi3 support

Signed-off-by: Sergio Lopez <redacted>
* kompute: op_mul_mat_q4_k permutted support

Signed-off-by: Sergio Lopez <redacted>
* kompute: op_mul_mat_[q4_0|q4_1|q8_0] permutted support

Signed-off-by: Sergio Lopez <redacted>
* kompute: op_mul_mat_f16 permutted support

Signed-off-by: Sergio Lopez <redacted>
* kompute: op_mul_mat_q6_k permutted support

Signed-off-by: Sergio Lopez <redacted>
---------

Signed-off-by: Sergio Lopez <redacted>
16 files changed:
ggml/src/ggml-kompute/CMakeLists.txt
ggml/src/ggml-kompute/ggml-kompute.cpp
ggml/src/ggml-kompute/kompute-shaders/common.comp
ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp
ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp
ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp
ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp
ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp
ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp [deleted file]
ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp [deleted file]
ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp [new file with mode: 0644]
ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp [new file with mode: 0644]
ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp [new file with mode: 0644]
ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp [new file with mode: 0644]
ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp
ggml/src/ggml-kompute/kompute-shaders/rope_common.comp

index dc623926c76850cc63d2504d2a1c890bb6cd4504..c9109d5e8ee191b7c16f14307cad8f303422523c 100644 (file)
@@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
         kompute-shaders/op_getrows_q4_0.comp
         kompute-shaders/op_getrows_q4_1.comp
         kompute-shaders/op_getrows_q6_k.comp
-        kompute-shaders/op_rope_f16.comp
-        kompute-shaders/op_rope_f32.comp
+        kompute-shaders/op_rope_norm_f16.comp
+        kompute-shaders/op_rope_norm_f32.comp
+        kompute-shaders/op_rope_neox_f16.comp
+        kompute-shaders/op_rope_neox_f32.comp
         kompute-shaders/op_cpy_f16_f16.comp
         kompute-shaders/op_cpy_f16_f32.comp
         kompute-shaders/op_cpy_f32_f16.comp
@@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
         shaderop_getrows_q4_0.h
         shaderop_getrows_q4_1.h
         shaderop_getrows_q6_k.h
-        shaderop_rope_f16.h
-        shaderop_rope_f32.h
+        shaderop_rope_norm_f16.h
+        shaderop_rope_norm_f32.h
+        shaderop_rope_neox_f16.h
+        shaderop_rope_neox_f32.h
         shaderop_cpy_f16_f16.h
         shaderop_cpy_f16_f32.h
         shaderop_cpy_f32_f16.h
index 24566404ded0fa3ed80ad6cb1b91c1d72727bdf9..28ceecfc40d666b391c8fff909059ccc4387c796 100644 (file)
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
+#include "shaderop_rope_norm_f16.h"
+#include "shaderop_rope_norm_f32.h"
+#include "shaderop_rope_neox_f16.h"
+#include "shaderop_rope_neox_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
     std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
         vk::DescriptorPoolSize(
           vk::DescriptorType::eStorageBuffer,
-          3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+          4 * size // Descriptor count is number of possible tensors to pass into an algorithm
           )
     };
 
@@ -788,7 +790,8 @@ static void ggml_vk_soft_max(
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
     int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
-    float scale
+    float scale, float max_bias, float m0, float m1,
+    uint32_t n_head_log2
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
         kp::shader_data::op_softmax_comp_spv_len);
@@ -796,12 +799,14 @@ static void ggml_vk_soft_max(
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
         int32_t ne00, ne01, ne02;
-        float scale;
+        float scale, max_bias, m0, m1;
+        uint32_t n_head_log2;
         int32_t mask;
     } pushConsts {
         safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
-        scale,
+        scale, max_bias, m0, m1,
+        n_head_log2,
         bool(inB)
     };
 
@@ -911,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
     int32_t ne00, int32_t ne01, int32_t ne02,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02,
+    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
     int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    uint32_t nb10, uint32_t nb11, uint32_t nb12,
+    uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
     int32_t ne0, int32_t ne1,
     uint32_t r2, uint32_t r3
 ) {
@@ -923,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
         int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02;
+        uint32_t nb00, nb01, nb02, nb03;
         int32_t ne10, ne11, ne12;
-        uint32_t nb10, nb11, nb12;
+        uint32_t nb10, nb11, nb12, nb13;
         int32_t ne0, ne1;
         uint32_t r2, r3;
     } pushConsts {
         safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
-        nb00, nb01, nb02,
+        nb00, nb01, nb02, nb03,
         ne10, ne11, ne12,
-        nb10, nb11, nb12,
+        nb10, nb11, nb12, nb13,
         ne0, ne1,
         r2, r3
     };
@@ -1013,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
     int32_t ne00, int32_t ne01, int32_t ne02,
     int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
     int32_t ne0, int32_t ne1,
+    uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    uint32_t nb11, uint32_t nb12, uint32_t nb13,
     uint32_t r2, uint32_t r3
 ) {
     struct PushConstants {
@@ -1020,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
         int32_t ne00, ne01, ne02;
         int32_t ne10, ne12;
         int32_t ne0, ne1;
+        uint32_t nb01, nb02, nb03;
+        uint32_t nb11, nb12, nb13;
         uint32_t r2, r3;
     } pushConsts {
         safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
         ne10, ne12,
         ne0, ne1,
+        nb01, nb02, nb03,
+        nb11, nb12, nb13,
         r2, r3
     };
 
     auto name = std::string(__func__) + "_" + suffix;
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+        const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
         s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
@@ -1074,19 +1085,26 @@ static void ggml_vk_mul_mat_q4_k(
     const std::shared_ptr<kp::Tensor>& inB,
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
-    int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
-    int32_t ne1, int32_t r2, int32_t r3
+    int32_t ne00, int32_t ne01, int32_t ne02,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    int32_t ne0, int32_t ne1,
+    uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    uint32_t nb11, uint32_t nb12, uint32_t nb13,
+    uint32_t r2, uint32_t r3
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
         kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
+        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+        uint32_t r2, r3;
     } pushConsts {
-        0, 0, 0,
-        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
+        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+        nb01, nb02, nb03, nb11, nb12, nb13,
+        r2, r3
     };
 
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
@@ -1108,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
     const std::shared_ptr<kp::Tensor>& inB,
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
-    int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
+    int32_t ne00, int32_t ne01, int32_t ne02,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    int32_t ne0, int32_t ne1,
+    uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    uint32_t nb11, uint32_t nb12, uint32_t nb13,
+    uint32_t r2, uint32_t r3
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
         kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
+        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+        uint32_t r2, r3;
     } pushConsts {
         inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne10, ne0, ne1, ne01, ne12/ne02
+        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+        nb01, nb02, nb03, nb11, nb12, nb13,
+        r2, r3
     };
 
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
+        const uint32_t local_x = 2;
+        const uint32_t local_y = ggml_vk_current_device().subgroupSize;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
     } else {
         s_algo = komputeManager()->getAlgorithm(__func__);
         s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
     }
@@ -1217,10 +1244,11 @@ static void ggml_vk_rope(
     kp::Sequence& seq,
     const std::shared_ptr<kp::Tensor>& inA,
     const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& inC,
     const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
     ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+    float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
     int32_t ne0,
@@ -1228,11 +1256,17 @@ static void ggml_vk_rope(
 ) {
     GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
 
-    static const auto spirv_f16 = getSpirvShader(
-        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    static const auto spirv_norm_f16 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+    );
+    static const auto spirv_norm_f32 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
+    );
+    static const auto spirv_neox_f16 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
     );
-    static const auto spirv_f32 = getSpirvShader(
-        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    static const auto spirv_neox_f32 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
     );
 
     int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@@ -1247,32 +1281,40 @@ static void ggml_vk_rope(
     GGML_ASSERT(nb0  % type_size == 0);
 
     struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
+        uint32_t inAOff, inBOff, inCOff, outOff;
         int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+        float freq_base, freq_scale;
+        bool has_freq_factors;
+        float ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
         n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+        freq_base, freq_scale,
+        has_freq_factors,
+        ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
         nb0, nb1, nb2, nb3
     };
 
-    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+    auto & inC_ = inC ? inC : inA;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_f16 = src0t == GGML_TYPE_F16;
+
+    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
+        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
         s_algo = komputeManager()->algorithm<float, PushConstants>(
-            name, s_kompute_context->pool.get(), {inA, inB, out},
-            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
             {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
         );
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
+        s_algo->setTensors({inA, inB, inC_, out});
         s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
@@ -1351,11 +1393,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
 }
 
 static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    int64_t n = ggml_nelements(op);
     switch (op->op) {
         case GGML_OP_UNARY:
+            if (n % 4 != 0) return false;
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
+                    if (n % 8 != 0) return false;
+                    // fall through
+                case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SILU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
@@ -1413,8 +1459,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
 
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
-                case GGML_TYPE_Q6_K:
                     return op->ne[3] == 1;
+                case GGML_TYPE_Q6_K:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
@@ -1515,9 +1561,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
             uint32_t off_src0 = 0;
             uint32_t off_src1 = 0;
+            uint32_t off_src2 = 0;
             uint32_t off_dst  = 0;
             const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
             const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+            const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
             const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
 
             switch (dst->op) {
@@ -1593,11 +1641,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                         GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
 
-#pragma message("TODO: add ALiBi support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
-                        GGML_ASSERT(max_bias == 0.0f);
+                        const int64_t nrows_x = ggml_nrows(src0);
+                        const int64_t nrows_y = src0->ne[1];
+
+                        const uint32_t n_head      = nrows_x/nrows_y;
+                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
-                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                     {
@@ -1649,38 +1702,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                             case GGML_TYPE_F16:
                                 ggml_vk_mul_mat_f16(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+                                    ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+                                    ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                                     ne0, ne1, r2, r3
                                 );
                                 break;
                             case GGML_TYPE_Q8_0:
                                 ggml_vk_mul_mat_q8_0(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                                 );
                                 break;
                             case GGML_TYPE_Q4_0:
                                 ggml_vk_mul_mat_q4_0(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                                 );
                                 break;
                             case GGML_TYPE_Q4_1:
                                 ggml_vk_mul_mat_q4_1(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                                 );
                                 break;
                             case GGML_TYPE_Q4_K:
                                 ggml_vk_mul_mat_q4_k(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                                 );
                                 break;
                             case GGML_TYPE_Q6_K:
                                 ggml_vk_mul_mat_q6_k(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                                 );
                                 break;
                             default: {
@@ -1709,13 +1768,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     } break;
                 case GGML_OP_ROPE:
                     {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                        GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
                         GGML_ASSERT(ne10 == ne02);
                         GGML_ASSERT(src0t == dstt);
                         // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1724,6 +1776,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                         const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
+                        const bool has_freq_factors = dst->src[2] != nullptr;
+
                         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                         memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
                         memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
@@ -1732,8 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
                         ggml_vk_rope(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                            seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+                            freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
                             ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                         );
                     } break;
index 2aaddf704a758b3fb54ba1a99db25f5b22e7925f..dbe4cf804e6c0d1723c7c57f70ba135ab6b91bf8 100644 (file)
@@ -3,6 +3,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_float16: require
 #extension GL_EXT_shader_explicit_arithmetic_types_int8: require
 #extension GL_EXT_shader_explicit_arithmetic_types_int16: require
+#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
 #extension GL_EXT_control_flow_attributes: enable
 #extension GL_KHR_shader_subgroup_arithmetic : require
 #extension GL_EXT_debug_printf : enable
index 8f0a9031f7a377db9465013b686acfa604b2d2c3..0ab1b2fc20eebfc86465083da209bf71fde880c8 100644 (file)
@@ -20,12 +20,14 @@ layout (push_constant) uniform parameter {
     uint nb00;
     uint nb01;
     uint nb02;
+    uint nb03;
     int ne10;
     int ne11;
     int ne12;
     uint nb10;
     uint nb11;
     uint nb12;
+    uint nb13;
     int ne0;
     int ne1;
     uint r2;
@@ -42,7 +44,7 @@ void main() {
     const uint i12 = im%pcs.ne12;
     const uint i13 = im/pcs.ne12;
 
-    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
+    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
 
     const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
 
@@ -52,7 +54,7 @@ void main() {
             break;
         }
 
-        const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
+        const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 
         float sumf = 0;
         for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
index fc8e45aa9777610f125aee58b9bf70dca73c3fc2..a5752a3a0065f54910fbb05442a4ca4b36c43132 100644 (file)
@@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
     int ne01;
     int ne02;
     int ne12;
-    int r2;
-    int r3;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    uint r2;
+    uint r3;
 } pcs;
 
 void main() {
@@ -50,10 +56,11 @@ void main() {
     const uint i12 = im%pcs.ne12;
     const uint i13 = im/pcs.ne12;
 
-    const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+    const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
+    const uint offset1 =        r1*pcs.nb11 + (i12       )*pcs.nb12 + (i13       )*pcs.nb13;
 
-    const uint xblk = ib_row + offset0 + pcs.inAOff;
-    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
+    const uint xblk = offset0 + pcs.inAOff;
+    const uint y = (offset1 / 4) + pcs.inBOff;
 
     float yl[16];
     float yh[16];
@@ -74,7 +81,7 @@ void main() {
         }
 
         for (int row = 0; row < N_DST; row++) {
-            uint row_idx = row * nb;
+            uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
 
             uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
             uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
index c9baebdf4baac6a0006d869449a85bf522ee6590..d331d1a70572ee1f6bd4b42ce2be53f5adc5b7b3 100644 (file)
@@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
     int ne0;
     int ne1;
     int ne01;
-    int gqa;
+    int ne02;
+    int ne12;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    uint r2;
+    uint r3;
 } pcs;
 
 void main() {
@@ -34,12 +43,15 @@ void main() {
 
     const uint r0 = gl_WorkGroupID.x;
     const uint r1 = gl_WorkGroupID.y;
-    const uint r2 = gl_WorkGroupID.z;
+    const uint im = gl_WorkGroupID.z;
 
     const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
-    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
-    const uint x = row * nb + offset0; // Based from inA without base offset
-    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+
+    const uint i12 = im%pcs.ne12;
+    const uint i13 = im/pcs.ne12;
+
+    const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
+    const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 
     float sumf = 0;
 
@@ -89,6 +101,6 @@ void main() {
 
     const float tot = subgroupAdd(sumf);
     if (subgroupElect()) {
-        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
+        out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
     }
 }
index 440b5ab2c81f887375587da58d0a1d946f2081ed..a6517cc1f1993cb455ed6075a97701c3a7c223df 100644 (file)
@@ -14,10 +14,15 @@ void main() {
     const uint i12 = im%pcs.ne12;
     const uint i13 = im/pcs.ne12;
 
-    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+    // pointers to src0 rows
+    uint ax[N_ROWS];
+    for (int row = 0; row < N_ROWS; ++row) {
+        const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
+
+        ax[row] = offset0 + pcs.inAOff;
+    }
 
-    const uint x = offset0; // Based from inA without base offset
-    const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+    const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 
     float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
 
@@ -32,8 +37,7 @@ void main() {
 
     for (uint ib = ix; ib < nb; ib += 16) {
         for (int row = 0; row < N_ROWS; row++) {
-            const uint block_index = x + ib + row * nb;
-            sumf[row] += block_q_n_dot_y(block_index, yb, il);
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
         }
 
         yb += BLOCKS_IN_QUANT * 16;
index 7912b09ac69c42b0a636263a8212d14d5d6b95c7..a9a2f22180ffd45b408ce53ab4d6d6a4585fb03f 100644 (file)
@@ -1,5 +1,5 @@
 layout(local_size_x_id = 0) in;
-layout(local_size_y = 1) in;
+layout(local_size_y = 8) in;
 layout(local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
@@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
     int  ne12;
     int  ne0;
     int  ne1;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
     uint r2;
     uint r3;
 } pcs;
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp
deleted file mode 100644 (file)
index 0ecfb2e..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    const int p = inB[pcs.inBOff + i2];
-
-    float theta = float(p);
-
-    if (!is_neox) {
-        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+1]);
-
-            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
-        }
-    } else {
-        const float inv_ndims = -1.f/pcs.n_dims;
-        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-            const uint cur_rot = ic;
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint i0 = ic/2;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+pcs.n_dims/2]);
-
-            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
-        }
-
-        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-            const uint i0 = ic;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            out_[dst_data + 0] = inA[src + 0];
-            out_[dst_data + 1] = inA[src + 1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp
deleted file mode 100644 (file)
index cec0fd9..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    const int p = inB[pcs.inBOff + i2];
-
-    float theta = float(p);
-
-    if (!is_neox) {
-        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+1];
-
-            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        const float inv_ndims = -1.f/pcs.n_dims;
-        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-            const uint cur_rot = ic;
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint i0 = ic/2;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+pcs.n_dims/2];
-
-            out_[dst_data] = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
-        }
-
-        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-            const uint i0 = ic;
-
-            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
-
-            out_[dst_data + 0] = inA[src + 0];
-            out_[dst_data + 1] = inA[src + 1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp
new file mode 100644 (file)
index 0000000..63659cb
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
+layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
+layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    float theta_base = float(inB[pcs.inBOff + i2]);
+    float inv_ndims = -1.f/pcs.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
+        if (i0 < pcs.n_dims) {
+            uint ic = i0/2;
+
+            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
+
+            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            const float x0 = float(inA[src]);
+            const float x1 = float(inA[src+pcs.n_dims/2]);
+
+            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
+            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
+        } else {
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            out_[dst_data]   = inA[src];
+            out_[dst_data+1] = inA[src+1];
+        }
+    }
+}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp
new file mode 100644 (file)
index 0000000..4df5620
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
+layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
+layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    float theta_base = float(inB[pcs.inBOff + i2]);
+    float inv_ndims = -1.f/pcs.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
+        if (i0 < pcs.n_dims) {
+            uint ic = i0/2;
+
+            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
+
+            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            const float x0 = inA[src];
+            const float x1 = inA[src+pcs.n_dims/2];
+
+            out_[dst_data]              = x0*cos_theta - x1*sin_theta;
+            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            out_[dst_data]   = inA[src];
+            out_[dst_data+1] = inA[src+1];
+        }
+    }
+}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp
new file mode 100644 (file)
index 0000000..a3c0eda
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
+layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
+layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    float theta_base = float(inB[pcs.inBOff + i2]);
+    float inv_ndims = -1.f/pcs.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
+        if (i0 < pcs.n_dims) {
+            uint ic = i0/2;
+
+            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
+
+            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            const float x0 = float(inA[src]);
+            const float x1 = float(inA[src+1]);
+
+            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
+            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
+        } else {
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            out_[dst_data]   = inA[src];
+            out_[dst_data+1] = inA[src+1];
+        }
+    }
+}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp
new file mode 100644 (file)
index 0000000..b7963ae
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
+layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
+layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    float theta_base = float(inB[pcs.inBOff + i2]);
+    float inv_ndims = -1.f/pcs.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
+        if (i0 < pcs.n_dims) {
+            uint ic = i0/2;
+
+            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
+
+            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            const float x0 = inA[src];
+            const float x1 = inA[src+1];
+
+            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
+            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            out_[dst_data]   = inA[src];
+            out_[dst_data+1] = inA[src+1];
+        }
+    }
+}
index 7bc9176cabaae4f45c37adc2356bb57f1c96c176..4165295bf4b3c2ba61fcc97431e1f67bb2452b61 100644 (file)
@@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
     int ne01;
     int ne02;
     float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint n_head_log2;
     int mask;
 } pcs;
 
@@ -34,17 +38,29 @@ void main() {
     const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
     const uint pdst = extra_off + pcs.outOff; // Based from out_
 
+    float slope = 1.0f;
+
+    // ALiBi
+    if (pcs.max_bias > 0.0f) {
+        int64_t h = i02;
+
+        float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
+        int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
+
+        slope = pow(base, float(exp));
+    }
+
     // parallel max
     float localMax = uintBitsToFloat(0xFF800000);
     for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
+        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
     }
     float max_ = subgroupMax(localMax);
 
     // parallel sum
     float localSum = 0.0f;
     for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
+        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
         localSum += exp_psrc0;
         out_[pdst + i00] = exp_psrc0;
     }
index df4702896d46f2f5e42a3328d18e6212a0f3d9d2..0fca640dcc232d5d4699e12116b317091d96dbae 100644 (file)
@@ -8,12 +8,14 @@ layout(local_size_x = 1) in;
 layout (push_constant) uniform parameter {
     uint inAOff;
     uint inBOff;
+    uint inCOff;
     uint outOff;
     int n_dims;
     int mode;
     int n_ctx_orig;
     float freq_base;
     float freq_scale;
+    bool has_freq_factors;
     float ext_factor;
     float attn_factor;
     float beta_fast;