]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Fix models with output size != 32000 (#2480)
authorJohannes Gäßler <redacted>
Wed, 2 Aug 2023 14:48:10 +0000 (16:48 +0200)
committerGitHub <redacted>
Wed, 2 Aug 2023 14:48:10 +0000 (16:48 +0200)
CMakeLists.txt
ggml-cuda.cu

index 1d4e63f3e72c9eb12c62878731f790eb4e62191f..d085bc835cf38b886d9a2f36ca717e1c25232e3f 100644 (file)
@@ -280,8 +280,8 @@ if (LLAMA_CUBLAS)
         # 52 == lowest CUDA 12 standard
         # 60 == f16 CUDA intrinsics
         # 61 == integer CUDA intrinsics
-        # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
-        if (LLAMA_CUDA_DMMV_F16)
+        # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+        if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
             set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
         else()
             set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
index f11fbe57c11123736bc07ccb2d9dd16092f3b55f..a4dd6bb9df99769e90e3898423884383861ca984 100644 (file)
@@ -162,7 +162,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
 typedef void (*load_tiles_cuda_t)(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row);
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
 typedef float (*vec_dot_q_mul_mat_cuda_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
@@ -1404,9 +1404,9 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 **
     *x_dm = tile_x_d;
 }
 
-static __device__ __forceinline__ void load_tiles_q4_0(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1420,7 +1420,11 @@ static __device__ __forceinline__ void load_tiles_q4_0(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1433,6 +1437,7 @@ static __device__ __forceinline__ void load_tiles_q4_0(
 
 // #pragma unroll
 //     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_0) {
+//         FIXME out-of-bounds
 //         const int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
 
 //         if (i >= GGML_CUDA_MMQ_Y) {
@@ -1513,9 +1518,9 @@ static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 **
     *x_dm = tile_x_dm;
 }
 
-static __device__ __forceinline__ void load_tiles_q4_1(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1529,7 +1534,11 @@ static __device__ __forceinline__ void load_tiles_q4_1(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1541,7 +1550,11 @@ static __device__ __forceinline__ void load_tiles_q4_1(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_1) {
-        const int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -1617,9 +1630,9 @@ static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 **
     *x_dm = tile_x_d;
 }
 
-static __device__ __forceinline__ void load_tiles_q5_0(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1633,7 +1646,11 @@ static __device__ __forceinline__ void load_tiles_q5_0(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1645,7 +1662,11 @@ static __device__ __forceinline__ void load_tiles_q5_0(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
-        const int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -1733,9 +1754,9 @@ static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 **
     *x_dm = tile_x_dm;
 }
 
-static __device__ __forceinline__ void load_tiles_q5_1(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1749,7 +1770,11 @@ static __device__ __forceinline__ void load_tiles_q5_1(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1761,7 +1786,11 @@ static __device__ __forceinline__ void load_tiles_q5_1(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_1) {
-        const int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -1824,9 +1853,9 @@ static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 **
     *x_dm = tile_x_d;
 }
 
-static __device__ __forceinline__ void load_tiles_q8_0(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1840,7 +1869,11 @@ static __device__ __forceinline__ void load_tiles_q8_0(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1853,6 +1886,7 @@ static __device__ __forceinline__ void load_tiles_q8_0(
 
 // #pragma unroll
 //     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI8_0) {
+//         FIXME out-of-bounds
 //         const int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
 
 // #if GGML_CUDA_MMQ_Y < 64
@@ -1947,9 +1981,9 @@ static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-static __device__ __forceinline__ void load_tiles_q2_K(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -1963,7 +1997,11 @@ static __device__ __forceinline__ void load_tiles_q2_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -1975,7 +2013,11 @@ static __device__ __forceinline__ void load_tiles_q2_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI2_K) {
-        const int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -1984,7 +2026,11 @@ static __device__ __forceinline__ void load_tiles_q2_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
-        const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
 
@@ -2099,9 +2145,9 @@ static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-static __device__ __forceinline__ void load_tiles_q3_K(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -2115,7 +2161,11 @@ static __device__ __forceinline__ void load_tiles_q3_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -2127,7 +2177,11 @@ static __device__ __forceinline__ void load_tiles_q3_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
-        const int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -2136,7 +2190,11 @@ static __device__ __forceinline__ void load_tiles_q3_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
-        const int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
 
@@ -2145,7 +2203,11 @@ static __device__ __forceinline__ void load_tiles_q3_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
-        const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
 
@@ -2320,9 +2382,9 @@ static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-static __device__ __forceinline__ void load_tiles_q4_K(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -2336,7 +2398,11 @@ static __device__ __forceinline__ void load_tiles_q4_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -2348,7 +2414,11 @@ static __device__ __forceinline__ void load_tiles_q4_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
-        const int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -2357,7 +2427,11 @@ static __device__ __forceinline__ void load_tiles_q4_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
 
@@ -2548,9 +2622,9 @@ static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-static __device__ __forceinline__ void load_tiles_q5_K(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -2564,7 +2638,11 @@ static __device__ __forceinline__ void load_tiles_q5_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -2576,7 +2654,11 @@ static __device__ __forceinline__ void load_tiles_q5_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
-        const int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -2585,7 +2667,11 @@ static __device__ __forceinline__ void load_tiles_q5_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
-        const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI5_K/4);
 
@@ -2594,7 +2680,11 @@ static __device__ __forceinline__ void load_tiles_q5_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
 
@@ -2717,9 +2807,9 @@ static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-static __device__ __forceinline__ void load_tiles_q6_K(
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
     __builtin_assume(i_offset <  8);
@@ -2733,7 +2823,11 @@ static __device__ __forceinline__ void load_tiles_q6_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
-        const int i = i0 + i_offset;
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
 
@@ -2745,7 +2839,11 @@ static __device__ __forceinline__ void load_tiles_q6_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
-        const int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
@@ -2754,7 +2852,11 @@ static __device__ __forceinline__ void load_tiles_q6_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
-        const int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI6_K/2);
 
@@ -2763,7 +2865,11 @@ static __device__ __forceinline__ void load_tiles_q6_K(
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
 
@@ -2849,7 +2955,7 @@ static __global__ void mul_mat_q(
     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
 
         load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
-                   tid_y, tid_x, blocks_per_row_x);
+                   tid_y, nrows_x-row_x_0-1, tid_x, blocks_per_row_x);
 
         for (int ir = 0; ir < qr; ++ir) {
             const int kqs = ir*WARP_SIZE + tid_x;
@@ -2873,7 +2979,7 @@ static __global__ void mul_mat_q(
 
         __syncthreads();
 
-#if __CUDA_ARCH__ >= 700 // TODO: actually test this with compute capability 7.X cards
+#if __CUDA_ARCH__ >= 700 // Unrolling the loop is slower on Pascal
 #pragma unroll
 #endif // __CUDA_ARCH__ >= 700
         for (int k = 0; k < WARP_SIZE/vdr; ++k) {
@@ -3609,8 +3715,14 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q4_1_q8_1_cuda(
@@ -3621,8 +3733,14 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q5_0_q8_1_cuda(
@@ -3633,8 +3751,14 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q5_1_q8_1_cuda(
@@ -3645,8 +3769,14 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q8_0_q8_1_cuda(
@@ -3657,8 +3787,14 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q2_K_q8_1_cuda(
@@ -3669,8 +3805,14 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q3_K_q8_1_cuda(
@@ -3681,8 +3823,14 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q4_K_q8_1_cuda(
@@ -3693,8 +3841,14 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q5_K_q8_1_cuda(
@@ -3705,8 +3859,14 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_q6_K_q8_1_cuda(
@@ -3717,8 +3877,14 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
 }
 
 static void ggml_mul_mat_p021_f16_f32_cuda(
@@ -4664,8 +4830,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
             row_low -= row_low % GGML_CUDA_MMQ_Y;
 
-            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
-            row_high -= row_high % GGML_CUDA_MMQ_Y;
+            if (id == g_device_count - 1) {
+                row_high = nrows0;
+            } else {
+                row_high = nrows0*g_tensor_split[id + 1];
+                row_high -= row_high % GGML_CUDA_MMQ_Y;
+            }
         } else {
             row_low = 0;
             row_high = nrows0*i02_divisor;
@@ -5145,8 +5315,12 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
             row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
             row_low -= row_low % GGML_CUDA_MMQ_Y;
 
-            row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
-            row_high -= row_high % GGML_CUDA_MMQ_Y;
+            if (id == g_device_count - 1) {
+                row_high = nrows;
+            } else {
+                row_high = nrows*g_tensor_split[id + 1];
+                row_high -= row_high % GGML_CUDA_MMQ_Y;
+            }
         } else {
             GGML_ASSERT(false);
         }