]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix compile error on Windows CUDA (#2207)
authorHoward Su <redacted>
Thu, 13 Jul 2023 13:58:09 +0000 (21:58 +0800)
committerGitHub <redacted>
Thu, 13 Jul 2023 13:58:09 +0000 (21:58 +0800)
ggml-cuda.cu

index dc4b773a66a4451dab6f209cc44751dc852adaeb..e0d5e9156315ba4ed74e385c819e75fe614d3652 100644 (file)
@@ -267,10 +267,9 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
-static const float GELU_COEF_A    = 0.044715f;
-static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const float GELU_COEF_A    = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -2300,7 +2299,7 @@ inline void ggml_cuda_op_add(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    const int64_t ne10 = src1->ne[0];
+    // const int64_t ne10 = src1->ne[0];
 
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {