]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Clean up ggml-cuda.cu warnings when compiling with clang (for ROCM) (#4124)
authorKerfuffle <redacted>
Sat, 18 Nov 2023 15:11:18 +0000 (08:11 -0700)
committerGitHub <redacted>
Sat, 18 Nov 2023 15:11:18 +0000 (08:11 -0700)
* ggml-cuda.cu: Clean up warnings when compiling with clang

* ggml-cuda.cu: Move static items into anonymous namespace

* ggml-cuda.cu: Fix use of namespace start macro

* Revert "ggml-cuda.cu: Fix use of namespace start macro"

This reverts commit 26c11490266c096e3e5731e05270a8f73a5b2874.

* Revert "ggml-cuda.cu: Move static items into anonymous namespace"

This reverts commit e29757e0f7535d1ac314300f0324684cc785e06c.

ggml-cuda.cu

index 874ad9ac4e8eca03ffbc07443e6cf98b732bb87e..50e03de5007472b82f9e0bb667f098b1213bad83 100644 (file)
@@ -235,7 +235,7 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
 static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
 
     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -245,7 +245,7 @@ static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const
 }
 
 static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
 
     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -255,11 +255,11 @@ static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, con
 }
 
 static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
 static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
 template<typename T>
@@ -469,7 +469,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
 #define MAX_STREAMS 8
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -2248,6 +2248,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
@@ -2259,7 +2260,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-
+    (void)x_qh; (void)x_sc;
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
     GGML_CUDA_ASSUME(k >= 0);
@@ -2268,7 +2269,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
 
-    const block_q4_0 * bx0 = (block_q4_0 *) vx;
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
 
     float * x_dmf = (float *) x_dm;
 
@@ -2306,9 +2307,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
 
     int u[2*VDR_Q4_0_Q8_1_MMQ];
 
@@ -2342,6 +2344,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
@@ -2353,6 +2356,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2362,7 +2366,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
 
-    const block_q4_1 * bx0 = (block_q4_1 *) vx;
+    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2397,6 +2401,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
@@ -2434,6 +2439,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
@@ -2445,6 +2451,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2454,7 +2461,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
 
-    const block_q5_0 * bx0 = (block_q5_0 *) vx;
+    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2509,6 +2516,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
@@ -2548,6 +2556,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
@@ -2559,6 +2568,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset < nwarps);
@@ -2568,7 +2578,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
 
-    const block_q5_1 * bx0 = (block_q5_1 *) vx;
+    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2620,6 +2630,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -2654,6 +2665,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
 
     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
@@ -2665,6 +2677,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2675,7 +2688,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kqsx = k % QI8_0;
     float * x_dmf = (float *) x_dm;
 
-    const block_q8_0 * bx0 = (block_q8_0 *) vx;
+    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2710,6 +2723,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
 
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -2743,6 +2757,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
@@ -2756,6 +2771,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2765,7 +2781,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
 
-    const block_q2_K * bx0 = (block_q2_K *) vx;
+    const block_q2_K * bx0 = (const block_q2_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2813,6 +2829,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const int kbx = k / QI2_K;
     const int ky  = (k % QI2_K) * QR2_K;
@@ -2886,7 +2903,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
 
-    const block_q3_K * bx0 = (block_q3_K *) vx;
+    const block_q3_K * bx0 = (const block_q3_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2967,7 +2984,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
 
-    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
     int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 
@@ -3082,6 +3099,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
@@ -3095,6 +3113,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3104,7 +3123,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
 
-    const block_q4_K * bx0 = (block_q4_K *) vx;
+    const block_q4_K * bx0 = (const block_q4_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3149,7 +3168,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
 
-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;
 
         const int ksc = k % (WARP_SIZE/8);
 
@@ -3164,6 +3183,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
@@ -3263,6 +3283,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
@@ -3276,6 +3297,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3285,7 +3307,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
 
-    const block_q5_K * bx0 = (block_q5_K *) vx;
+    const block_q5_K * bx0 = (const block_q5_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3341,7 +3363,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
 
-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;
 
         const int ksc = k % (WARP_SIZE/8);
 
@@ -3356,6 +3378,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
 
@@ -3392,6 +3415,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;
 
     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
@@ -3405,6 +3429,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;
 
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3414,7 +3439,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256
 
-    const block_q6_K * bx0 = (block_q6_K *) vx;
+    const block_q6_K * bx0 = (const block_q6_K *) vx;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3476,6 +3501,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
 
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -3518,7 +3544,7 @@ static __device__ __forceinline__ void mul_mat_q(
     __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
     __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
 
-    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f};
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
 
     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
 
@@ -6023,18 +6049,18 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
         return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
-    } else if (nb0 == ts) {
+    }
+    if (nb0 == ts) {
         return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
-    } else {
-        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
-            if (r != cudaSuccess) return r;
-        }
-        return cudaSuccess;
     }
+    for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+        const void * rx = (const void *) ((const char *) x + i1*nb1);
+        void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+        // pretend the row is a matrix with cols=1
+        cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+        if (r != cudaSuccess) { return r; }
+    }
+    return cudaSuccess;
 }
 
 static void ggml_cuda_op_repeat(
@@ -6989,7 +7015,7 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t nrows0 = ggml_nrows(src0);
+    // const int64_t nrows0 = ggml_nrows(src0);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -7090,7 +7116,7 @@ static void ggml_cuda_op_mul_mat(
         if (src0_on_device && src0_is_contiguous) {
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
-            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
@@ -7323,7 +7349,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (!g_cublas_loaded) return false;
+    if (!g_cublas_loaded) { return false; }
 
     const int64_t ne10 = src1->ne[0];
 
@@ -7401,7 +7427,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-__global__ void k_compute_batched_ptrs(
+__global__ static void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
@@ -8017,7 +8043,7 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    if (!g_cublas_loaded) return false;
+    if (!g_cublas_loaded) { return false; }
 
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -8316,14 +8342,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
     UNUSED(cgraph);
 }
 
-static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");
 
     UNUSED(backend);
     UNUSED(plan);
 }
 
-static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");
 
     UNUSED(backend);
@@ -8339,8 +8365,9 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
-        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
             continue;
+        }
         assert(node->backend == GGML_BACKEND_GPU);
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {