]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add freq factors (llama/7495)
authorAidanBeltonS <redacted>
Mon, 27 May 2024 12:34:09 +0000 (13:34 +0100)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
src/ggml-sycl.cpp

index 496ec61c3c28a6789b20838990e1e985c0a519c1..f329bc27265feab2cd2a1b0b806bb35e050286a9 100644 (file)
@@ -8830,12 +8830,11 @@ static void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
-,
-    const sycl::nd_item<3> &item_ct1) {
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
     const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                          item_ct1.get_local_id(1));
 
@@ -8863,8 +8862,10 @@ static void rope_neox(
     float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
     const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
+        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -12413,7 +12414,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
                            const int32_t *pos, float freq_scale,
                            int p_delta_rows, float freq_base, float ext_factor,
                            float attn_factor, rope_corr_dims corr_dims,
-                           dpct::queue_ptr stream) {
+                           const float * freq_factors, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,38 +12424,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
-        /*
-        DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, inv_ndims,
-                                    item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                        p_delta_rows, ext_factor, attn_factor,
+                                        corr_dims, theta_scale, inv_ndims, freq_factors,
+                                        item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                        p_delta_rows, ext_factor, attn_factor,
+                                        corr_dims, theta_scale, inv_ndims, freq_factors,
+                                        item_ct1);
+                });
+        }
     } else {
-        /*
-        DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                   p_delta_rows, ext_factor, attn_factor,
-                                   corr_dims, theta_scale, inv_ndims, item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                       p_delta_rows, ext_factor, attn_factor,
+                                       corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                       p_delta_rows, ext_factor, attn_factor,
+                                       corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        }
     }
 }
 
@@ -13986,9 +13997,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
                               ggml_tensor *dst, const float *src0_dd,
                               const float *src1_dd, float *dst_dd,
                               const dpct::queue_ptr &main_stream) {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+    const ggml_tensor * src2 = dst->src[2];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
@@ -14014,6 +14023,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14024,6 +14034,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    if (is_neox) {
+        pos = (const int32_t *) src1_dd;
+
+        if (src2 != nullptr) {
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 
@@ -14035,13 +14055,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
+                attn_factor, corr_dims, freq_factors, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
                            ne00, n_dims, nrows, pos, freq_scale, ne01,
                            freq_base, ext_factor, attn_factor, corr_dims,
-                           main_stream);
+                           freq_factors, main_stream);
         } else {
             GGML_ASSERT(false);
         }