]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : add Command R Plus support (llama/6491)
authorCarolinabanana <redacted>
Tue, 9 Apr 2024 08:16:13 +0000 (09:16 +0100)
committerGeorgi Gerganov <redacted>
Tue, 9 Apr 2024 17:16:09 +0000 (20:16 +0300)
* Add Command R Plus GGUF

* Add Command R Plus GGUF

* Loading works up to LayerNorm2D

* Export new tensors in 1D so they are not quantized.

* Fix embedding layer based on Noeda's example

* Whitespace

* Add line

* Fix unexpected tokens on MPS. Re-add F16 fix. ((Noeda)

* dranger003: Fix block index overflow in CUDA dequantizing.

* Reverted blocked multiplication code as it still has issues and could affect other Llama arches

* export norms as f32

* fix overflow issues during quant and other cleanup

* Type convention

Co-authored-by: Georgi Gerganov <redacted>
* dranger003: Fix more int overflow during quant.

---------

Co-authored-by: S <redacted>
Co-authored-by: S <redacted>
Co-authored-by: slaren <redacted>
Co-authored-by: Georgi Gerganov <redacted>
12 files changed:
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-cuda/common.cuh
src/ggml-cuda/convert.cu
src/ggml-cuda/convert.cuh
src/ggml-cuda/dequantize.cuh
src/ggml-cuda/dmmv.cu
src/ggml-cuda/quantize.cu
src/ggml-cuda/quantize.cuh
src/ggml-quants.c
src/ggml-quants.h
src/ggml.c

index 5cef45c0ba4ad13fefe8450f874918e6e74ed354..abe3767f22418b9580bee4c00b435692a94fbb27 100644 (file)
@@ -332,8 +332,8 @@ extern "C" {
     GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
     GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
 
-    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
-    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
+    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
+    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
 
     struct ggml_object;
     struct ggml_context;
@@ -2210,9 +2210,9 @@ extern "C" {
             enum ggml_type   type,
                const float * src,
                       void * dst,
-                       int   start,
-                       int   nrows,
-                       int   n_per_row,
+                   int64_t   start,
+                   int64_t   nrows,
+                   int64_t   n_per_row,
                const float * imatrix);
 
     //
@@ -2377,8 +2377,8 @@ extern "C" {
 #else
 #define GGML_RESTRICT restrict
 #endif
-    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int k);
+    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
     typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
                                       const void * GGML_RESTRICT y, size_t by, int nrc);
 
index ce28cb55d01b2f70d47dcefc88c624ec5ba5b598..bff8ad9d96e887aaf0c8206955642d1bcbf22eec 100644 (file)
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
+    int64_t ldc = id == ctx.device ? ne0 : row_diff;
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];
 
     GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
index b98d7cbd0c1c501f9f2e8a1327596af9c0c923de..481065b2a3484b350f258b981a5b4efae291477a 100644 (file)
@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
 
 //////////////////////
index 18a31edc34f8354177a47683572c3352d9329aeb..ed4fa2748972b3fc3064bdfd8afc32510b5b01d1 100644 (file)
@@ -4,14 +4,14 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
         return;
     }
 
-    const int ib = i/qk; // block index
+    const int64_t ib = i/qk; // block index
     const int iqs = (i%qk)/qr; // quant index
     const int iybs = i - i%qk; // y block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
 }
 
 template <bool need_check>
-static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
 #if __CUDA_ARCH__ >= CC_PASCAL
     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
 
@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
 template<typename dst_t>
 static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
     const int ir  = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
     const int ir  = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -313,14 +313,14 @@ template<typename dst_t>
 static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 #if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
-    const int ip  = tid/32;   // ip is 0 or 1
-    const int il  = tid - 32*ip; // 0...32
-    const int is  = 8*ip + il/16;
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/32;   // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;
 
     dst_t * y = yy + i*QK_K + 128*ip + il;
 
@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #else
 
     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int ip  = tid/16;         // 0 or 1
-    const int il  = tid - 16*ip;    // 0...15
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/16;         // 0 or 1
+    const int64_t il  = tid - 16*ip;    // 0...15
 
     dst_t * y = yy + i*QK_K + 16*ip + il;
 
@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 #endif
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
     if (k % CUDA_Q8_0_NE_ALIGN == 0) {
         const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
 }
 
 template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
 #if QK_K == 64
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
 }
 
 template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
index db34c0be96449706c2b3e6d441ffbd9a54b312da..5394be9f161b3410cf62bf23eb04d05f611ecf17 100644 (file)
@@ -3,7 +3,7 @@
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;
index b54400632a5ece6020e59835a52e03d592c1f48e..bd3c2d9db94639f87b92a7cc52185a633e8e8cb5 100644 (file)
@@ -1,6 +1,6 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
     const dfloat d = x[ib].d;
index 0b17e3cb961f743d8db4263020d7dd1d49f57ab0..7313e3e175367594c67e012d50bd2292eb4b3cb1 100644 (file)
@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
     }
 }
 
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
 
     // automatic half -> float type cast if dfloat == float
@@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 
     for (int i = 0; i < ncols; i += iter_stride) {
         const int col = i + vals_per_iter*tid;
-        const int ib = (row*ncols + col)/qk; // x block index
+        const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
         const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
index a1fbc9932122f21df44b8ad83be1203fcfcad86c..7578c4b6c7cab7553914f3c84b45ffcdff87664b 100644 (file)
@@ -1,20 +1,20 @@
 #include "quantize.cuh"
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
-    const int ix = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
+    const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (ix >= kx_padded) {
         return;
     }
 
-    const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+    const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
 
-    const int i_padded = iy*kx_padded + ix;
+    const int64_t i_padded = (int64_t)iy*kx_padded + ix;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
+    const int64_t ib = i_padded / QK8_1; // block index
+    const int64_t iqs = i_padded % QK8_1; // quant index
 
     const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
     float amax = fabsf(xi);
@@ -36,8 +36,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
-    const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
+    const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
     const dim3 num_blocks(block_num_x, ky, 1);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
index adb89c83aac85f3316bf8098ac4bf0bd9e155232..b37a4752f2d24e2de0b1241511f8085721ffc5f2 100644 (file)
@@ -2,4 +2,4 @@
 
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream);
+void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
index f2e6c4bd1a3216b21a2d071bc0c60ce22c9ab777..32e84434a8c1b8b271ac44a75f76632726a1c99e 100644 (file)
@@ -544,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -581,12 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
     }
 }
 
-void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q4_0_reference(x, y, k);
 }
 
 
-void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
     const int qk = QK4_1;
 
     assert(k % qk == 0);
@@ -623,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
     }
 }
 
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q4_1_reference(x, y, k);
 }
 
-void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
     static const int qk = QK5_0;
 
     assert(k % qk == 0);
@@ -671,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
     }
 }
 
-void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q5_0_reference(x, y, k);
 }
 
-void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
     const int qk = QK5_1;
 
     assert(k % qk == 0);
@@ -719,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
     }
 }
 
-void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q5_1_reference(x, y, k);
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
@@ -749,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict
     }
 }
 
-void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -938,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
     assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
@@ -973,7 +973,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict
     }
 }
 
-void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
 
@@ -1192,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
 #endif
 }
 
-void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -1212,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int
     }
 }
 
-void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK4_1;
 
     assert(k % qk == 0);
@@ -1233,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int
     }
 }
 
-void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK5_0;
 
     assert(k % qk == 0);
@@ -1259,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int
     }
 }
 
-void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK5_1;
 
     assert(k % qk == 0);
@@ -1286,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int
     }
 }
 
-void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK8_0;
 
     assert(k % qk == 0);
@@ -1581,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
 
 //========================- 2-bit (de)-quantization
 
-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1658,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
     }
 }
 
-void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -1704,7 +1704,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
     }
 }
 
-void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
     quantize_row_q2_K_reference(x, vy, k);
 }
 
@@ -1960,14 +1960,14 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
     }
 }
 
-size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+        quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
-        for (int row = 0; row < nrow; ++row) {
+        for (int64_t row = 0; row < nrow; ++row) {
             quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
             src += n_per_row;
             qrow += row_size;
@@ -1978,7 +1978,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow,
 
 //========================= 3-bit (de)-quantization
 
-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2092,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
 }
 
 #if QK_K == 256
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2142,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
     }
 }
 #else
-void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     assert(QK_K == 64);
     const int nb = k / QK_K;
@@ -2175,11 +2175,11 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
 }
 #endif
 
-void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
     quantize_row_q3_K_reference(x, vy, k);
 }
 
-static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
+static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
 #if QK_K != 256
     (void)quant_weights;
     quantize_row_q3_K_reference(x, y, n_per_row);
@@ -2268,14 +2268,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 #endif
 }
 
-size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
+        quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
-        for (int row = 0; row < nrow; ++row) {
+        for (int64_t row = 0; row < nrow; ++row) {
             quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
             src += n_per_row;
             qrow += row_size;
@@ -2286,7 +2286,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow,
 
 // ====================== 4-bit (de)-quantization
 
-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2393,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     }
 }
 
-void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2432,19 +2432,19 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
     }
 }
 
-void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q4_K * restrict y = vy;
     quantize_row_q4_K_reference(x, y, k);
 }
 
-static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
 #if QK_K != 256
     (void)quant_weights;
     quantize_row_q4_K_reference(x, y, n_per_row);
 #else
     assert(n_per_row % QK_K == 0);
-    const int nb = n_per_row / QK_K;
+    const int64_t nb = n_per_row / QK_K;
 
     uint8_t L[QK_K];
     uint8_t Laux[32];
@@ -2516,14 +2516,14 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
 #endif
 }
 
-size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
+        quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
-        for (int row = 0; row < nrow; ++row) {
+        for (int64_t row = 0; row < nrow; ++row) {
             quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
             src += n_per_row;
             qrow += row_size;
@@ -2534,9 +2534,9 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow,
 
 // ====================== 5-bit (de)-quantization
 
-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
 #if QK_K == 256
     uint8_t L[QK_K];
@@ -2676,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     }
 }
 
-void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -2721,19 +2721,19 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
     }
 }
 
-void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q5_K * restrict y = vy;
     quantize_row_q5_K_reference(x, y, k);
 }
 
-static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
 #if QK_K != 256
     (void)quant_weights;
     quantize_row_q5_K_reference(x, y, n_per_row);
 #else
     assert(n_per_row % QK_K == 0);
-    const int nb = n_per_row / QK_K;
+    const int64_t nb = n_per_row / QK_K;
 
     uint8_t L[QK_K];
     uint8_t Laux[32];
@@ -2825,14 +2825,14 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
 #endif
 }
 
-size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
+        quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
-        for (int row = 0; row < nrow; ++row) {
+        for (int64_t row = 0; row < nrow; ++row) {
             quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
             src += n_per_row;
             qrow += row_size;
@@ -2843,9 +2843,9 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow,
 
 // ====================== 6-bit (de)-quantization
 
-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     int8_t L[QK_K];
     float   scales[QK_K/16];
@@ -2925,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
     }
 }
 
-void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -2972,19 +2972,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
     }
 }
 
-void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q6_K * restrict y = vy;
     quantize_row_q6_K_reference(x, y, k);
 }
 
-static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
 #if QK_K != 256
     (void)quant_weights;
     quantize_row_q6_K_reference(x, y, n_per_row);
 #else
     assert(n_per_row % QK_K == 0);
-    const int nb = n_per_row / QK_K;
+    const int64_t nb = n_per_row / QK_K;
 
     int8_t L[QK_K];
     float   scales[QK_K/16];
@@ -3067,14 +3067,14 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
 #endif
 }
 
-size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
+        quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
-        for (int row = 0; row < nrow; ++row) {
+        for (int64_t row = 0; row < nrow; ++row) {
             quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
             src += n_per_row;
             qrow += row_size;
@@ -3083,7 +3083,7 @@ size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow,
     return nrow * row_size;
 }
 
-static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK4_0 == 32, "QK4_0 must be 32");
 
     if (!quant_weights) {
@@ -3098,7 +3098,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
     float sigma2 = sum_x2/n_per_row;
 
-    const int nb = n_per_row/QK4_0;
+    const int64_t nb = n_per_row/QK4_0;
     for (int ib = 0; ib < nb; ++ib) {
         const float * xb = x + QK4_0 * ib;
         const float * qw = quant_weights + QK4_0 * ib;
@@ -3111,14 +3111,14 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     }
 }
 
-size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_0_reference(src, dst, nrow*n_per_row);
+        quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += row_size;
@@ -3126,7 +3126,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow,
     return nrow * row_size;
 }
 
-static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK4_1 == 32, "QK4_1 must be 32");
 
     if (!quant_weights) {
@@ -3141,7 +3141,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
     float sigma2 = sum_x2/n_per_row;
 
-    const int nb = n_per_row/QK4_1;
+    const int64_t nb = n_per_row/QK4_1;
     for (int ib = 0; ib < nb; ++ib) {
         const float * xb = x + QK4_1 * ib;
         const float * qw = quant_weights + QK4_1 * ib;
@@ -3156,14 +3156,14 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
     }
 }
 
-size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_1_reference(src, dst, nrow*n_per_row);
+        quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += row_size;
@@ -3171,7 +3171,7 @@ size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow,
     return nrow * row_size;
 }
 
-static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK5_0 == 32, "QK5_0 must be 32");
 
     if (!quant_weights) {
@@ -3186,7 +3186,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
     float sigma2 = sum_x2/n_per_row;
 
-    const int nb = n_per_row/QK5_0;
+    const int64_t nb = n_per_row/QK5_0;
     for (int ib = 0; ib < nb; ++ib) {
         const float * xb = x + QK5_0 * ib;
         const float * qw = quant_weights + QK5_0 * ib;
@@ -3210,14 +3210,14 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     }
 }
 
-size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_0_reference(src, dst, nrow*n_per_row);
+        quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += row_size;
@@ -3225,7 +3225,7 @@ size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow,
     return nrow * row_size;
 }
 
-static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) {
+static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
     static_assert(QK5_1 == 32, "QK5_1 must be 32");
 
     if (!quant_weights) {
@@ -3240,7 +3240,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
     float sigma2 = sum_x2/n_per_row;
 
-    const int nb = n_per_row/QK5_1;
+    const int64_t nb = n_per_row/QK5_1;
     for (int ib = 0; ib < nb; ++ib) {
         const float * xb = x + QK5_1 * ib;
         const float * qw = quant_weights + QK5_1 * ib;
@@ -3263,14 +3263,14 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
     }
 }
 
-size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_1_reference(src, dst, nrow*n_per_row);
+        quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += row_size;
@@ -3278,18 +3278,18 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow,
     return nrow * row_size;
 }
 
-size_t quantize_q8_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
-    quantize_row_q8_0_reference(src, dst, nrow*n_per_row);
+    quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }
 
 // ====================== "True" 2-bit (de)-quantization
 
-void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     uint32_t aux32[2];
     const uint8_t * aux8 = (const uint8_t *)aux32;
@@ -3315,9 +3315,9 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
 
 // ====================== 2.3125 bpw (de)-quantization
 
-void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     float db[2];
 
@@ -3342,9 +3342,9 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
 
 // ====================== 2.5625 bpw (de)-quantization
 
-void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     float db[2];
 
@@ -3374,9 +3374,9 @@ void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, in
 
 // ====================== 3.0625 bpw (de)-quantization
 
-void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     uint32_t aux32;
 
@@ -3406,9 +3406,9 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
 
 // ====================== 3.3125 bpw (de)-quantization
 
-void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -3449,9 +3449,9 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
 
 // ====================== 1.5625 bpw (de)-quantization
 
-void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
+void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -3474,9 +3474,9 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     }
 }
 
-void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) {
+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     float delta[4];
     uint16_t idx[4];
@@ -3535,9 +3535,9 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
 
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) {
     assert(k % QK4_NL == 0);
-    const int nb = k / QK4_NL;
+    const int64_t nb = k / QK4_NL;
 
     for (int i = 0; i < nb; i++) {
 
@@ -3553,12 +3553,12 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
     }
 }
 
-void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
+void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
 #if QK_K == 64
     dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
 #else
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -3582,9 +3582,9 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
 
 //===================================== Q8_K ==============================================
 
-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
 
@@ -3621,9 +3621,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
     }
 }
 
-void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
         for (int j = 0; j < QK_K; ++j) {
@@ -3632,7 +3632,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int
     }
 }
 
-void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q8_K_reference(x, y, k);
 }
 
@@ -10648,7 +10648,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
 
@@ -10664,7 +10664,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
 
     const int kMaxQ = 3;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     block_iq2_xxs * y = vy;
 
@@ -10821,7 +10821,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     }
 }
 
-static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
 
@@ -10837,7 +10837,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
 
     const int kMaxQ = 3;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     block_iq2_xs * y = vy;
 
@@ -11001,11 +11001,11 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
     }
 }
 
-size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq2_xxs);
@@ -11013,11 +11013,11 @@ size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nro
     return nrow * nblock * sizeof(block_iq2_xxs);
 }
 
-size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq2_xs);
@@ -11242,7 +11242,7 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
-static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n,
         const float * restrict quant_weights) {
 
     const int gindex = iq3_data_index(grid_size);
@@ -11259,7 +11259,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
 
     const int kMaxQ = 8;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     ggml_fp16_t * dh;
     uint8_t * qs;
@@ -11455,11 +11455,11 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
     }
 }
 
-size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq3_xxs);
@@ -11467,13 +11467,13 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nro
     return nrow * nblock * sizeof(block_iq3_xxs);
 }
 
-void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq3_xxs * restrict y = vy;
     quantize_row_iq3_xxs_reference(x, y, k);
 }
 
-void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
+void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
 }
@@ -11504,7 +11504,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
 
     const int kMaxQ = 8;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     block_iq3_s * y = vy;
 
@@ -11661,9 +11661,9 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
 }
 
 #define IQ3S_BLOCK_SIZE 32
-size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     float scales[QK_K/IQ3S_BLOCK_SIZE];
     float weight[IQ3S_BLOCK_SIZE];
     float xval[IQ3S_BLOCK_SIZE];
@@ -11674,7 +11674,7 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow,
     bool   is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
     uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
                 scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
         src += n_per_row;
@@ -11683,13 +11683,13 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow,
     return nrow * nblock * sizeof(block_iq3_s);
 }
 
-void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq3_s * restrict y = vy;
     quantize_row_iq3_s_reference(x, y, k);
 }
 
-void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
+void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq3_s(x, y, 1, k, NULL);
 }
@@ -11822,7 +11822,7 @@ static int iq1_sort_helper(const void * left, const void * right) {
 
 #define IQ1S_BLOCK_SIZE 32
 #define IQ1M_BLOCK_SIZE 16
-static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
         float    * scales,
         float    * weight,
         float    * sumx,
@@ -11846,7 +11846,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
 
     block_iq1_s * y = vy;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     const int block_size = IQ1S_BLOCK_SIZE;
 
@@ -11980,7 +11980,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     float  scales[QK_K/IQ1S_BLOCK_SIZE];
     float  weight[IQ1S_BLOCK_SIZE];
@@ -11990,9 +11990,9 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow,
     float  pairs[2*IQ1S_BLOCK_SIZE];
     uint16_t index[IQ1S_BLOCK_SIZE/8];
     int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq1_s);
@@ -12000,7 +12000,7 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow,
     return nrow * nblock * sizeof(block_iq1_s);
 }
 
-static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
         float    * scales,
         float    * weight,
         float    * pairs,
@@ -12022,7 +12022,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
 
     block_iq1_m * y = vy;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     const int block_size = IQ1M_BLOCK_SIZE;
 
@@ -12265,7 +12265,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
     float  scales[QK_K/IQ1M_BLOCK_SIZE];
     float  weight[IQ1M_BLOCK_SIZE];
@@ -12273,9 +12273,9 @@ size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow,
     float  pairs[2*IQ1M_BLOCK_SIZE];
     uint16_t index[IQ1M_BLOCK_SIZE/8];
     int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq1_m);
@@ -12407,16 +12407,16 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
     }
 }
 
-size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK4_NL == 0);
-    int nblock = n_per_row/QK4_NL;
+    int64_t nblock = n_per_row/QK4_NL;
     char * qrow = (char *)dst;
     uint8_t L[QK4_NL];
     float weight[QK4_NL];
     uint16_t unused_h;
     uint8_t * unused_l = NULL;
     float scale;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
             const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
@@ -12429,9 +12429,9 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
     return nrow * nblock * sizeof(block_iq4_nl);
 }
 
-void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) {
     GGML_ASSERT(k%QK4_NL == 0);
-    int nblock = k/QK4_NL;
+    int64_t nblock = k/QK4_NL;
     uint8_t L[QK4_NL];
     float weight[QK4_NL];
     uint16_t unused_h;
@@ -12444,22 +12444,22 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
     }
 }
 
-void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
+void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
     assert(k % QK4_NL == 0);
     quantize_row_iq4_nl(x, y, k);
 }
 
-size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
 #if QK_K == 64
     return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights);
 #else
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
     uint8_t L[QK_K];
     float weight[32];
     float scales[QK_K/32];
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
             const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
@@ -12473,20 +12473,20 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
 #endif
 }
 
-void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq4_xs * restrict y = vy;
     quantize_row_iq4_xs_reference(x, y, k);
 }
 
-void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) {
+void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq4_xs(x, y, 1, k, NULL);
 }
 
 // =============================== 2.5625 bpw
 
-static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);
 
@@ -12501,7 +12501,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
 
     const int kMaxQ = 3;
 
-    const int nbl = n/QK_K;
+    const int64_t nbl = n/QK_K;
 
     block_iq2_s * y = vy;
 
@@ -12654,11 +12654,11 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
     }
 }
 
-size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
-    int nblock = n_per_row/QK_K;
+    int64_t nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
-    for (int row = 0; row < nrow; ++row) {
+    for (int64_t row = 0; row < nrow; ++row) {
         quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq2_s);
@@ -12666,12 +12666,12 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow,
     return nrow * nblock * sizeof(block_iq2_s);
 }
 
-void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int k) {
+void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq2_s(x, y, 1, k, NULL);
 }
 
-void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int k) {
+void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq2_s * restrict y = vy;
     quantize_row_iq2_s_reference(x, y, k);
index ac1091c3d3b66b8ca083a25b17479870ec94a1b5..4d436a8f06b3e5081205dd568eb12d17443d4ee0 100644 (file)
@@ -12,70 +12,70 @@ extern "C" {
 #endif
 
 // Quantization
-void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int k);
-void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int k);
-void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int k);
-void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int k);
-void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int k);
-void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int k);
-
-void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k);
-void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int k);
-void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int k);
-void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int k);
-void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k);
-void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
-
-void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
-void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int k);
-void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs  * GGML_RESTRICT y, int k);
-void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int k);
-void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s   * GGML_RESTRICT y, int k);
-
-void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-
-void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-
-void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-void quantize_row_iq2_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs  * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s   * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 // Dequantization
-void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-
-void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-
-void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -101,26 +101,26 @@ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
-size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-
-size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
-size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
+size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
 void iq2xs_init_impl(enum ggml_type type);
 void iq2xs_free_impl(enum ggml_type type);
index c9b0a6a0ef776af3a453d21c0575df97a8cc807a..793b67f4c70209e37098c26d1833abd3d5a9b4a6 100644 (file)
@@ -338,14 +338,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return GGML_FP32_TO_FP16(x);
 }
 
-void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) {
-    for (int i = 0; i < n; i++) {
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
+    for (int64_t i = 0; i < n; i++) {
         y[i] = GGML_FP16_TO_FP32(x[i]);
     }
 }
 
-void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
-    int i = 0;
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
+    int64_t i = 0;
 #if defined(__F16C__)
     for (; i + 7 < n; i += 8) {
         __m256 x_vec = _mm256_loadu_ps(x + i);
@@ -20331,11 +20331,11 @@ size_t ggml_quantize_chunk(
         enum ggml_type   type,
            const float * src,
                   void * dst,
-                   int   start,
-                   int   nrows,
-                   int   n_per_row,
+               int64_t   start,
+               int64_t   nrows,
+               int64_t   n_per_row,
            const float * imatrix) {
-    const int n = nrows * n_per_row;
+    const int64_t n = (int64_t) nrows * n_per_row;
 
     if (ggml_quantize_requires_imatrix(type)) {
         GGML_ASSERT(imatrix != NULL);