]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : fix RoPE after #2268 (#3897)
authorcebtenzzre <redacted>
Thu, 2 Nov 2023 05:49:44 +0000 (01:49 -0400)
committerGitHub <redacted>
Thu, 2 Nov 2023 05:49:44 +0000 (07:49 +0200)
ggml-cuda.cu

index 61cd1747cac4fc1f917e644031bb01bc80e66f0e..57a528ede23ed24da855e4ca2e4cfb87e2b9aeca 100644 (file)
@@ -4539,7 +4539,7 @@ static __global__ void rope(
     const int i2 = row/p_delta_rows;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -col/ncols);
+    const float theta_base = p*powf(freq_base, -float(col)/ncols);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(row * ncols + col) * (-1 / ncols)`
-    const float cur_rot = -col/ncols - row;
+    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
+    const float cur_rot = -float(col)/ncols;
 
     const int p = has_pos ? pos[i2] : 0;
     const float theta_base = p*powf(freq_base, cur_rot);