From: cebtenzzre
Date: Thu, 2 Nov 2023 05:49:44 +0000 (-0400)
Subject: cuda : fix RoPE after #2268 (#3897)
X-Git-Tag: upstream/0.0.4488~3022
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2fffa0d61fa10e4b466e78cabcc6a4e16717b580;p=pkg%2Fggml%2Fsources%2Fllama.cpp

cuda : fix RoPE after #2268 (#3897)
---

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 61cd1747..57a528ed 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4539,7 +4539,7 @@ static __global__ void rope(
     const int i2 = row/p_delta_rows;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -col/ncols);
+    const float theta_base = p*powf(freq_base, -float(col)/ncols);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(row * ncols + col) * (-1 / ncols)`
-    const float cur_rot = -col/ncols - row;
+    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
+    const float cur_rot = -float(col)/ncols;
 
     const int p = has_pos ? pos[i2] : 0;
     const float theta_base = p*powf(freq_base, cur_rot);