]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Fix more int overflow during quant (PPL/CUDA). (llama/6563)
authorDAN™ <redacted>
Sun, 28 Apr 2024 22:38:44 +0000 (18:38 -0400)
committerGeorgi Gerganov <redacted>
Sat, 11 May 2024 18:30:08 +0000 (21:30 +0300)
* Fix more int overflow during quant.

* Fix some more int overflow in softmax.

* Revert back to int64_t.

src/ggml-cuda/convert.cu
src/ggml-cuda/softmax.cu

index b15e3578267b3837354d37352d01c700c31be12a..75e50c98561235c572365217281afeaf5a7ae247 100644 (file)
@@ -5,16 +5,16 @@
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
         return;
     }
 
     const int64_t ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    const int64_t iqs = (i%qk)/qr; // quant index
+    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
@@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
 #if __CUDA_ARCH__ >= CC_PASCAL
     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
 
-    const int   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+    const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
     const int * x0 = ((int *) vx) + blockIdx.x * nint;
     half2 * y2 = (half2 *) (y + i0);
 
@@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
     const int64_t i = blockIdx.x;
 
     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int il  = tid/8;
-    const int ir  = tid%8;
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
     const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
@@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
     const int64_t i = blockIdx.x;
 
     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int il  = tid/8;
-    const int ir  = tid%8;
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
     const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
@@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int n   = tid/32;
-    const int l   = tid - 32*n;
-    const int is  = 8*n + l/16;
+    const int64_t n   = tid/32;
+    const int64_t l   = tid - 32*n;
+    const int64_t is  = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
     dst_t * y = yy + i*QK_K + 128*n;
@@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 #else
-    const int is = tid/16;  // 0 or 1
-    const int il = tid%16;  // 0...15
+    const int64_t is = tid/16;  // 0 or 1
+    const int64_t il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     dst_t * y = yy + i*QK_K + 16*is + il;
     float dall = __low2half(x[i].dm);
@@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
 #if QK_K == 256
-    const int r = threadIdx.x/4;
-    const int tid = r/2;
-    const int is0 = r%2;
-    const int l0 = 16*is0 + 4*(threadIdx.x%4);
-    const int n = tid / 4;
-    const int j = tid - 4*n;
+    const int64_t r = threadIdx.x/4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
 
     uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j + is0;
+    int64_t is = 8*n + 2*j + is0;
     int shift = 2*j;
 
     int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 #else
-    const int tid = threadIdx.x;
-    const int is  = tid/16;  // 0 or 1
-    const int il  = tid%16;  // 0...15
-    const int im  = il/8;    // 0...1
-    const int in  = il%8;    // 0...7
+    const int64_t tid = threadIdx.x;
+    const int64_t is  = tid/16;  // 0 or 1
+    const int64_t il  = tid%16;  // 0...15
+    const int64_t im  = il/8;    // 0...1
+    const int64_t in  = il%8;    // 0...7
 
     dst_t * y = yy + i*QK_K + 16*is + il;
 
@@ -227,15 +227,15 @@ template<typename dst_t>
 static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
 #if QK_K == 256
     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int is  = 2*il;
-    const int n   = 4;
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t is  = 2*il;
+    const int64_t n   = 4;
 
     dst_t * y = yy + i*QK_K + 64*il + n*ir;
 
@@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
 #else
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
@@ -268,14 +268,14 @@ template<typename dst_t>
 static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
 #if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
-    const int il  = tid/16;   // il is in 0...3
-    const int ir  = tid%16;   // ir is in 0...15
-    const int is  = 2*il;     // is is in 0...6
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/16;   // il is in 0...3
+    const int64_t ir  = tid%16;   // ir is in 0...15
+    const int64_t is  = 2*il;     // is is in 0...6
 
     dst_t * y = yy + i*QK_K + 64*il + 2*ir;
 
@@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 #else
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
-    const int is = tid/16; // 0 or 1
+    const int64_t im = tid/8;  // 0...3
+    const int64_t in = tid%8;  // 0...7
+    const int64_t is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
     dst_t * y = yy + i*QK_K + tid;
@@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * aux8 = (const uint8_t *)q2;
@@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq2_xs * x = (const block_iq2_xs *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq2_s * x = (const block_iq2_s *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t  * q3 = x[i].qs + 8*ib;
     const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq3_s * x = (const block_iq3_s *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * qs = x[i].qs + 8*ib;
     const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq1_s * x = (const block_iq1_s  *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
     const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq1_m * x = (const block_iq1_m  *) vx;
 
-    const int tid = threadIdx.x;
+    const int64_t tid = threadIdx.x;
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * sc = (const uint16_t *)x[i].scales;
     iq1m_scale_t scale;
     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
     const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
     const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
     uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
@@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
 
-    const int tid = threadIdx.x;
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[ib].qs + 4*il;
     const float d = (float)x[ib].d;
@@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
 #if QK_K != 64
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
-    const int i   = blockIdx.x;
+    const int64_t i   = blockIdx.x;
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
-    const int tid = threadIdx.x;
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
     const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
index 9bda18e581c7506c8e3cb95fb5c6778951d918c8..fa8f987cf7c1d4e3c26fe60bca485556f9d85e02 100644 (file)
@@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
 
     float max_val = -INFINITY;
 
@@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             break;
         }
 
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
 
         const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
 
@@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             return;
         }
 
-        const int idst = rowx*ncols + col;
+        const int64_t idst = (int64_t)rowx*ncols + col;
         dst[idst] = vals[col] * inv_sum;
     }
 }