]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
RoPE: fix back, CUDA support for back + noncont. (llama/11240)
authorJohannes Gäßler <redacted>
Wed, 15 Jan 2025 11:51:37 +0000 (12:51 +0100)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
* RoPE: fix back, CUDA support for back + noncont.

* fix comments reg. non-cont. RoPE support [no-ci]

ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/rope.cu
ggml/src/ggml-cuda/rope.cuh
ggml/src/ggml.c

index 8f8cb9e1aa1401c536db42620d6809be5b55b7c8..a9c051cd5d691586fa6a443bb6c656dcbd9b9e72 100644 (file)
@@ -1500,7 +1500,7 @@ extern "C" {
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
-    GGML_API struct ggml_tensor * ggml_rope_back(
+    GGML_API struct ggml_tensor * ggml_rope_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, // gradients of ggml_rope result
             struct ggml_tensor  * b, // positions
@@ -1515,6 +1515,23 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+
     // clamp
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_clamp(
index 424c9781f44277b80460de972e52fa70b003cdc5..8bf5f781a599ef2f6b88eba42f495d696ef29ce7 100644 (file)
@@ -13668,6 +13668,7 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_SOFT_MAX:
                 case GGML_OP_ROPE:
+                case GGML_OP_ROPE_BACK:
                     {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     } break;
index f11399cc628cadd77cf06d44c034e7e49b961a2e..5c47ceb7314577abe8e8563bf4ba889329416adc 100644 (file)
@@ -403,8 +403,6 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-        case GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
index 1dac397c4b0836b2111f0c2f910c7edfff695d53..9118edc7282268c6ba7abbece154eab99019241a 100644 (file)
@@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_cuda_op_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ROPE_BACK: {
+            const size_t ts = ggml_type_size(op->src[0]->type);
+            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
+            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+        }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
@@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
             return op->ne[1];
         case GGML_OP_MUL_MAT_ID:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
             return op->ne[2];
         default:
             return ggml_nrows(op);
index 2c84778d29c9be2760d21c37cfd2d46975857c52..e1912fee1f9abd35657f2b1530939903be34447d 100644 (file)
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 
 // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+template<bool forward>
 static __device__ void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
+        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
+        float mscale, float & cos_theta, float & sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
         // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
     }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
+    cos_theta = cosf(theta) * mscale;
+    sin_theta = sinf(theta) * mscale;
+    if (!forward) {
+        sin_theta *= -1.0f;
+    }
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
         return;
     }
 
-    const int i  = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
+
+    const int idst = row_dst*ne0 + i0;
+    const int ix   = channel_x*s2 + row_x*s1 + i0;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + 1];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
+
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
-        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_ff>
+template<bool forward, bool has_ff, typename T>
 static __global__ void rope_vision(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows; // i2-th tokens
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
-    int sect_dims = sections.v[0] + sections.v[1];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int sect_dims = sections.v[0] + sections.v[1];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base = pos[i2]*powf(theta_scale, p);
+        theta_base = pos[channel_x]*powf(theta_scale, p);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims];
 
-    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_norm_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_multi_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-template<typename T>
+template<bool forward, typename T>
 static void rope_vision_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-
-    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_multi_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_multi_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <bool forward>
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ne02 = src0->ne[2]; // num heads
     const int64_t nr = ggml_nrows(src0);
 
+    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // compute
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_multi_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_multi_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_vision_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_vision_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda<forward>(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda<forward>(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 }
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl<true>(ctx, dst);
+}
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl<false>(ctx, dst);
+}
index 0f787a0b2f7cd0613c4073dc232925be5d7fea3d..9139f3b220df7756c0dd984a5250f1bffa83d9d2 100644 (file)
@@ -3,3 +3,5 @@
 #define CUDA_ROPE_BLOCK_SIZE 256
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index ae8147dd7d2f7af7650c071389624351aa0c95d4..0a0dcc7e3b1da600a6610dff15ae0171882d14d1 100644 (file)
@@ -3699,7 +3699,7 @@ void ggml_rope_yarn_corr_dims(
 
 // ggml_rope_back
 
-struct ggml_tensor * ggml_rope_back(
+struct ggml_tensor * ggml_rope_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -3713,29 +3713,32 @@ struct ggml_tensor * ggml_rope_back(
         float                 attn_factor,
         float                 beta_fast,
         float                 beta_slow) {
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] == b->ne[0]);
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params +  5, &freq_base,    sizeof(float));
-    memcpy(params +  6, &freq_scale,   sizeof(float));
-    memcpy(params +  7, &ext_factor,   sizeof(float));
-    memcpy(params +  8, &attn_factor,  sizeof(float));
-    memcpy(params +  9, &beta_fast,    sizeof(float));
-    memcpy(params + 10, &beta_slow,    sizeof(float));
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_ROPE_BACK;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
+    struct ggml_tensor * result = ggml_rope_ext(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
     return result;
 }
 
+struct ggml_tensor * ggml_rope_multi_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[4],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    struct ggml_tensor * result = ggml_rope_multi(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
+    return result;
+}
 // ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
@@ -5598,6 +5601,7 @@ static void ggml_compute_backward(
                 //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
                 const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
                 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                int sections[4] = {0, 0, 0, 0};
 
                 memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
                 memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
@@ -5605,10 +5609,14 @@ static void ggml_compute_backward(
                 memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
                 memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
-
-                ggml_add_or_set(ctx, cgraph, isrc0,
-                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
-                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+                memcpy(&sections,                    tensor->op_params + 11, sizeof(sections));
+
+                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
+                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
+                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
             }
             GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
         } break;