]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
llama : add Qwen2VL support + multimodal RoPE (llama/10361)
authorHimariO <redacted>
Sat, 14 Dec 2024 12:43:46 +0000 (20:43 +0800)
committerGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 10:52:16 +0000 (12:52 +0200)
* Barebone Qwen2VL LLM convertor

* Add Qwen2VL cli entrypoint

* [WIP] add qwen2vl arch

* Verify m-rope output

* Add vl-rope/2d-rope support for qwen2vl ViT

* update qwen2vl cli tool

* update 5D tensor op workaround

* [WIP] qwen2vl vision model

* make batch and clip utils compatible with qwen2vl

* [WIP] create inference workflow, gguf convert script but fix

* correcting vision-rope behavior, add the missing last layer back to ViT

* add arg parser to qwen2vl_surgery

* replace variable size array with vector

* cuda-gdb cmake preset

* add fp32 mrope, vision rope kernel

* add fp16 support for qwen2vl and m-rope

* add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`

* fix rope op mode switching, out dated func args

* update `llama_hparams`

* update to keep up stream changes

* resolve linter, test errors

* add makefile entry, update speical image padding token

* add mrope unit test, fix few compiler warnings

* rename `mrope` related function, params

* minor updates on debug util, bug fixs

* add `m-rope` testcase to `test-backend-ops`

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* fix traililng whitespce

* store `llama_hparams.rope_sections` with fixed size array

* update position id tensor size check in GGML_OP_ROPE

* minor updates

* update `ggml_backend_*_supports_op` of unsupported backends

* remote old `rope_section` compare operator

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/include/ggml.h
ggml/src/ggml-cann/ggml-cann.cpp
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cuda/rope.cu
ggml/src/ggml-kompute/ggml-kompute.cpp
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c

index 386d5a15d81a169943694d59557a85723b5c8a0e..b0c1ac9ce2b89629ab705aaed7adbae71cf7d898 100644 (file)
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
-#define GGML_ROPE_TYPE_NEOX 2
+#define GGML_ROPE_TYPE_NEOX   2
+#define GGML_ROPE_TYPE_MROPE  8
+#define GGML_ROPE_TYPE_VISION 24
 
 #define GGUF_MAGIC "GGUF"
 
@@ -1443,6 +1445,22 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
             struct ggml_context * ctx,
index fa04ab84f3f150f99f7dca8ce7e1f4e0bc5b640e..d410c02445c27cce30d1a3dbb6293649fa0d40a7 100644 (file)
@@ -1747,6 +1747,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (*ext_factor != 0) {
                 return false;
             }
+
+            const int mode = ((const int32_t *) op->op_params)[2];
+            if (mode & GGML_ROPE_TYPE_MROPE) {
+                return false;
+            }
+            if (mode & GGML_ROPE_TYPE_VISION) {
+                return false;
+            }
+
             return true;
         }
         case GGML_OP_UPSCALE: {
index 92df6fdda18f22c79b1660dd3705fb18cb56b9b7..67e67a08968a1e2bebe15e77bb033d74da58ac54 100644 (file)
@@ -9133,6 +9133,64 @@ static void ggml_rope_cache_init(
     }
 }
 
+static void ggml_mrope_cache_init(
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta_t = theta_base_t;
+    float theta_h = theta_base_h;
+    float theta_w = theta_base_w;
+    float theta_e = theta_base_e;  // extra position id for vision encoder
+    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+    int sec_w = sections[1] + sections[0];
+    int sec_e = sections[2] + sec_w;
+    GGML_ASSERT(sect_dims <= ne0);
+
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+
+        int sector = (i0 / 2) % sect_dims;
+        if (indep_sects) {
+            // compute theta independently for each dim sections
+            // (i.e. reset corresponding theta when `i0` go from one section to another)
+            if (sector == 0) {
+                theta_t = theta_base_t;
+            }
+            else if (sector == sections[0]) {
+                theta_h = theta_base_h;;
+            }
+            else if (sector == sec_w) {
+                theta_w = theta_base_w;
+            }
+            else if (sector == sec_e) {
+                theta_e = theta_base_e;
+            }
+        }
+
+        float theta = theta_t;
+        if (sector >= sections[0] && sector < sec_w) {
+            theta = theta_h;
+        }
+        else if (sector >= sec_w && sector < sec_w + sections[2]) {
+            theta = theta_w;
+        }
+        else if (sector >= sec_w + sections[2]) {
+            theta = theta_e;
+        }
+
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta_t *= theta_scale;
+        theta_w *= theta_scale;
+        theta_h *= theta_scale;
+        theta_e *= theta_scale;
+    }
+}
+
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -9143,6 +9201,7 @@ static void ggml_compute_forward_rope_f32(
     const struct ggml_tensor * src2 = dst->src[2];
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
@@ -9156,6 +9215,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -9188,6 +9248,16 @@ static void ggml_compute_forward_rope_f32(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
 
     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -9203,18 +9273,63 @@ static void ggml_compute_forward_rope_f32(
 
     const int32_t * pos = (const int32_t *) src1->data;
 
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
 
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                if (!is_neox) {
+                if (is_neox || is_mrope) {
+                    if (is_vision){
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims];
+
+                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
+
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
+                    }
+                } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
@@ -9228,8 +9343,10 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[0] = x0*cos_theta - x1*sin_theta;
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
                     }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                         const int64_t ic = i0/2;
 
                         const float cos_theta = cache[i0 + 0];
@@ -9239,19 +9356,20 @@ static void ggml_compute_forward_rope_f32(
                         float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
                         const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
+                        const float x1 = src[n_dims];
 
-                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
                     }
-                }
-
-                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                } else {
+                    // fill the remain channels with data from src tensor
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                    dst_data[0] = src[0];
-                    dst_data[1] = src[1];
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
                 }
             }
         }
@@ -9269,6 +9387,7 @@ static void ggml_compute_forward_rope_f16(
     const struct ggml_tensor * src2 = dst->src[2];
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
@@ -9281,6 +9400,8 @@ static void ggml_compute_forward_rope_f16(
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -9313,6 +9434,16 @@ static void ggml_compute_forward_rope_f16(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
 
     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -9330,16 +9461,61 @@ static void ggml_compute_forward_rope_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                if (!is_neox) {
+                if (is_neox || is_mrope) {
+                    if (is_vision) {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                            dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    }
+                } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
@@ -9353,8 +9529,10 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                         const int64_t ic = i0/2;
 
                         const float cos_theta = cache[i0 + 0];
@@ -9364,19 +9542,19 @@ static void ggml_compute_forward_rope_f16(
                         ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
                         const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
 
-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
-                }
-
-                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                } else {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                    dst_data[0] = src[0];
-                    dst_data[1] = src[1];
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
                 }
             }
         }
index 88f586d689cfd39a9b2ebf8f482c37d1501dd771..2c84778d29c9be2760d21c37cfd2d46975857c52 100644 (file)
@@ -4,6 +4,11 @@ struct rope_corr_dims {
     float v[2];
 };
 
+
+struct mrope_sections {
+    int v[4];
+};
+
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -108,6 +113,105 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
+template<typename T, bool has_ff>
+static __global__ void rope_multi(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_vision(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows; // i2-th tokens
+
+    int sect_dims = sections.v[0] + sections.v[1];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        const int p = sector;
+        theta_base = pos[i2]*powf(theta_scale, p);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        const int p = sector - sections.v[0];
+        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims];
+
+    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+}
+
 template<typename T>
 static void rope_norm_cuda(
     const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -156,6 +260,56 @@ static void rope_neox_cuda(
     }
 }
 
+template<typename T>
+static void rope_multi_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    } else {
+        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    }
+}
+
+template<typename T>
+static void rope_vision_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
+    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    } else {
+        rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    }
+}
+
 static void rope_norm_cuda_f16(
     const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -185,6 +339,38 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
+static void rope_multi_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_multi_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -201,8 +387,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    const int64_t ne00 = src0->ne[0]; // head dims
+    const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num heads
     const int64_t nr = ggml_nrows(src0);
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
@@ -210,6 +397,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int mode       = ((int32_t *) dst->op_params)[2];
     //const int n_ctx      = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    mrope_sections sections;
 
     // RoPE alteration for extended context
     float freq_base;
@@ -225,8 +413,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
 
     const int32_t * pos = (const int32_t *) src1_d;
 
@@ -253,6 +452,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         } else {
             GGML_ABORT("fatal error");
         }
+    } else if (is_mrope && !is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_multi_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_multi_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else if (is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_vision_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
     } else {
         if (src0->type == GGML_TYPE_F32) {
             rope_norm_cuda_f32(
index 28ceecfc40d666b391c8fff909059ccc4387c796..50579227183d3c80ec0805bce34b020b27ed4573 100644 (file)
@@ -1419,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_NORM:
-        case GGML_OP_ROPE:
             return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
index 34fe5778e4f090b84bef786a02e611ee6f840016..28f590f9216beed95e766a91046a4506d71f4b98 100644 (file)
@@ -1125,8 +1125,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
         case GGML_OP_ARGMAX:
         case GGML_OP_NORM:
-        case GGML_OP_ROPE:
             return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
@@ -3026,7 +3036,9 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
-                GGML_ASSERT(ne10 == ne02);
+                // make sure we have one or more position id(ne10) per token(ne02)
+                GGML_ASSERT(ne10 % ne02 == 0);
+                GGML_ASSERT(ne10 >= ne02);
 
                 const int nth = MIN(1024, ne00);
 
index 6b9f0b0d9a1c8726f186d7b25a8be5d6a26b0dd4..84f1328e7cf1cfeb5006836cae75fa65ef5676e8 100644 (file)
@@ -4488,7 +4488,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return ggml_is_contiguous(op->src[0]);
+            }
         case GGML_OP_IM2COL:
             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
index 515d66b39342629d5c581f2c3ddebba026110abb..a26a8fe092312932e44cace60099ff256aac644f 100644 (file)
@@ -7687,7 +7687,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT:
             return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return ggml_is_contiguous(op->src[0]);
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
index 54922ae22b9c6e3324153a40ea210ca44c9829f6..030d93a5177c0b7589795e0d89552e24a2a9b23c 100644 (file)
@@ -3517,15 +3517,18 @@ static struct ggml_tensor * ggml_rope_impl(
         GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
+    int sections[4] = {0, 0, 0, 0};
+
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &sections,     sizeof(int)*4);
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op     = GGML_OP_ROPE;
@@ -3547,6 +3550,53 @@ struct ggml_tensor * ggml_rope(
     );
 }
 
+struct ggml_tensor * ggml_rope_multi(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[4],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    // Multimodal Rotary Position Embedding
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(&params[11], sections,      sizeof(int)*4);
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_ROPE;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_rope_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,