]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sync : ggml (CUDA GLM RoPE + POSIX) (#3082)
authorGeorgi Gerganov <redacted>
Fri, 8 Sep 2023 14:58:07 +0000 (17:58 +0300)
committerGitHub <redacted>
Fri, 8 Sep 2023 14:58:07 +0000 (17:58 +0300)
ggml-ci

CMakeLists.txt
ggml-cuda.cu
ggml.c

index 0abf1df7b83a1a8693579641899b37d80442fe53..e6242dc31e33868c9da4ac524180d486fc108f17 100644 (file)
@@ -551,6 +551,10 @@ else()
     message(STATUS "Unknown architecture")
 endif()
 
+#
+# POSIX conformance
+#
+
 # clock_gettime came in POSIX.1b (1993)
 # CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
 # posix_memalign came in POSIX.1-2001 / SUSv3
@@ -560,39 +564,39 @@ add_compile_definitions(_XOPEN_SOURCE=600)
 # Somehow in OpenBSD whenever POSIX conformance is specified
 # some string functions rely on locale_t availability,
 # which was introduced in POSIX.1-2008, forcing us to go higher
-IF (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
     remove_definitions(-D_XOPEN_SOURCE=600)
     add_compile_definitions(_XOPEN_SOURCE=700)
-ENDIF()
+endif()
 
 # Data types, macros and functions related to controlling CPU affinity and
 # some memory allocation are available on Linux through GNU extensions in libc
-IF (CMAKE_SYSTEM_NAME MATCHES "Linux")
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
     add_compile_definitions(_GNU_SOURCE)
-ENDIF()
+endif()
 
 # RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
 # and on macOS its availability depends on enabling Darwin extensions
 # similarly on DragonFly, enabling BSD extensions is necessary
-IF (CMAKE_SYSTEM_NAME MATCHES "Darwin")
+if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
     add_compile_definitions(_DARWIN_C_SOURCE)
-ENDIF()
-IF (CMAKE_SYSTEM_NAME MATCHES "DragonFly")
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "DragonFly")
     add_compile_definitions(_DARWIN_C_SOURCE)
-ENDIF()
+endif()
 
 # alloca is a non-standard interface that is not visible on BSDs when
 # POSIX conformance is specified, but not all of them provide a clean way
 # to enable it in such cases
-IF (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
     add_compile_definitions(__BSD_VISIBLE)
-ENDIF()
-IF (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
     add_compile_definitions(_NETBSD_SOURCE)
-ENDIF()
-IF (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
     add_compile_definitions(_BSD_SOURCE)
-ENDIF()
+endif()
 
 #
 # libraries
index d2dbf824ef2da9145161d1f43dbfedd7e107ce6e..00e9bbeae44498240ef0b5da1877694e24f36afd 100644 (file)
@@ -4086,7 +4086,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
+                                    const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
@@ -4098,8 +4099,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     const int i = row*ncols + col;
 
     const float col_theta_scale = powf(theta_scale, col);
+    const float p = p0 + p_delta*(row/p_delta_rows);
 
-    const float theta = p*col_theta_scale;
+    const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4109,7 +4111,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + 0]           = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
 
-    const float block_theta = block_p*col_theta_scale;
+    const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);
 
@@ -4984,12 +4986,13 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
     rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 4 == 0);
-    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
-    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                              const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+    GGML_ASSERT(ncols % 4 == 0);
+    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -5723,22 +5726,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
-        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
-        const float id_p = min(p, n_ctx - 2.f);
-        const float block_p = max(p - (n_ctx - 2.f), 0.f);
-        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
@@ -6400,10 +6399,7 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
 
-    const int mode = ((int32_t *) dst->op_params)[2];
-    const bool is_glm = mode & 4;
-
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
 }
 
 void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
diff --git a/ggml.c b/ggml.c
index d5ca0101a64b0bfcc99e5212eb200ec4353a60f4..3f72379c3553e27bc3f685f8756f163a4aa5f860 100644 (file)
--- a/ggml.c
+++ b/ggml.c
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnigns
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
 #endif
 
 #if defined(_WIN32)
@@ -306,12 +310,14 @@ typedef double ggml_float;
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
 #if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@@ -18871,7 +18877,6 @@ static enum ggml_opt_result linesearch_backtracking(
                     // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
                     return count;
                 }
-                return count;
             }
         }