]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add IQ2 to test-backend-ops + refactoring (llama/4990)
authorGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 16:54:56 +0000 (18:54 +0200)
committerGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 18:44:11 +0000 (20:44 +0200)
* ggml : add IQ2 to test-backend-ops + refactoring

ggml-ci

* cuda : update supports_op for IQ2

ggml-ci

* ci : enable LLAMA_CUBLAS=1 for CUDA nodes

ggml-ci

* cuda : fix out-of-bounds-access in `mul_mat_vec_q`

ggml-ci

* tests : avoid creating RNGs for each Q tensor

ggml-ci

* tests : avoid creating RNGs for each tensor

ggml-ci

include/ggml/ggml.h
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-quants.c
src/ggml-quants.h
src/ggml.c
tests/test-backend-ops.cpp

index 27daf6fd1e12b6ccda8513a62e1e6a420e18930f..de8162b8135f3a208726ac6047cee9419a4be427 100644 (file)
@@ -2065,6 +2065,18 @@ extern "C" {
     // quantization
     //
 
+    // - ggml_quantize_init can be called multiple times with the same type
+    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
+    //   automatically called by ggml_quantize_chunk for convenience
+    //
+    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
+    //   call this at the end of the program to avoid memory leaks
+    //
+    // note: these are thread-safe
+    //
+    GGML_API void ggml_quantize_init(enum ggml_type type);
+    GGML_API void ggml_quantize_free(void);
+
     // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
@@ -2078,13 +2090,13 @@ extern "C" {
     GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    // some quantization type cannot be used without an importance matrix
+    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
+
+    // calls ggml_quantize_init internally (i.e. can allocate memory)
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst,
             int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
-    // These are needed for IQ2_XS and IQ2_XXS quantizations
-    GGML_API void ggml_init_iq2_quantization(enum ggml_type type);
-    GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type);
-
     //
     // gguf
     //
index 4266250f926eee602c2aaf650190dde72a22fd47..ef518dae0909b9a9d8ee3c1b886f701222b712c4 100644 (file)
@@ -692,6 +692,8 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
 
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     switch (op->op) {
+        case GGML_OP_CPY:
+            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
index 568c411afd3eed6026a401f3a5d61ba9ccc2a116..b2211d858c23aad6b8c1fe79a96f38b9c684afe5 100644 (file)
@@ -5131,10 +5131,10 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
 
-        const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
@@ -10918,6 +10918,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
+                ggml_type a_type = a->type;
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                        return false;
+                    }
+                }
                 return true;
             } break;
         case GGML_OP_GET_ROWS:
index 31b053e33578743bb6e82c88fdb5764cb09c739c..7d2f033e9a0fe8b288e1183fc53c9c75b40120e5 100644 (file)
@@ -1274,7 +1274,12 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
     }
     float sumlx = 0;
     float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
     for (int i = 0; i < n; ++i) {
+#endif
         int l = nearest_int(iscale * x[i]);
         l = MAX(-nmax, MIN(nmax-1, l));
         L[i] = l + nmax;
@@ -1649,7 +1654,12 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
     float max = x[0];
     float sum_w = weights ? weights[0] : x[0]*x[0];
     float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
     for (int i = 1; i < n; ++i) {
+#endif
         if (x[i] < min) min = x[i];
         if (x[i] > max) max = x[i];
         float w = weights ? weights[i] : x[i]*x[i];
@@ -1660,7 +1670,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
         min = 0;
     }
     if (max <= min) {
-        for (int i = 0; i < n; ++i) L[i] = 0;
+        memset(L, 0, n);
         *the_min = -min;
         return 0.f;
     }
@@ -1862,7 +1872,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
 
 size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2181,7 +2191,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 
 size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2448,7 +2458,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
 
 size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2771,7 +2781,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
 
 size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
     }
@@ -3025,7 +3035,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
 
 size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
     }
@@ -3072,7 +3082,7 @@ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
@@ -3116,7 +3126,7 @@ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
@@ -3169,7 +3179,7 @@ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
@@ -3221,7 +3231,7 @@ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
@@ -8565,7 +8575,7 @@ static int iq2_compare_func(const void * left, const void * right) {
     return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
 }
 
-static void q2xs_init_impl(int grid_size) {
+void iq2xs_init_impl(int grid_size) {
     const int gindex = iq2_data_index(grid_size);
     if (iq2_data[gindex].grid) {
         return;
@@ -8720,19 +8730,7 @@ static void q2xs_init_impl(int grid_size) {
     free(dist2);
 }
 
-void ggml_init_iq2_quantization(enum ggml_type type) {
-    if (type == GGML_TYPE_IQ2_XXS) {
-        q2xs_init_impl(256);
-    }
-    else if (type == GGML_TYPE_IQ2_XS) {
-        q2xs_init_impl(512);
-    }
-    else {
-        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
-    }
-}
-
-static void q2xs_deinit_impl(int grid_size) {
+void iq2xs_free_impl(int grid_size) {
     GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
     const int gindex = iq2_data_index(grid_size);
     if (iq2_data[gindex].grid) {
@@ -8742,18 +8740,6 @@ static void q2xs_deinit_impl(int grid_size) {
     }
 }
 
-void ggml_deinit_iq2_quantization(enum ggml_type type) {
-    if (type == GGML_TYPE_IQ2_XXS) {
-        q2xs_deinit_impl(256);
-    }
-    else if (type == GGML_TYPE_IQ2_XS) {
-        q2xs_deinit_impl(512);
-    }
-    else {
-        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
-    }
-}
-
 static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
         const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
     int num_neighbors = neighbours[0];
@@ -8786,10 +8772,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     const int      * kmap_q2xs       = iq2_data[gindex].map;
     const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
 
-    GGML_ASSERT(quant_weights);
-    GGML_ASSERT(kgrid_q2xs);
-    GGML_ASSERT(kmap_q2xs);
-    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
     const int kMaxQ = 3;
@@ -9005,10 +8991,10 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
     const int      * kmap_q2xs       = iq2_data[gindex].map;
     const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
 
-    GGML_ASSERT(quant_weights);
-    GGML_ASSERT(kmap_q2xs);
-    GGML_ASSERT(kgrid_q2xs);
-    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
     const int kMaxQ = 3;
index d7fefdb5479117fbf74318e38e6c3542e14d24ed..7d7cf9178f76e7714fc2080578342fbdd5e5ee65 100644 (file)
@@ -257,3 +257,6 @@ size_t quantize_q4_0   (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_q4_1   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q5_0   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q5_1   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
+void iq2xs_init_impl(int grid_size);
+void iq2xs_free_impl(int grid_size);
index 35fd29a9ec2dc899525d8d0f1d20d5480b17f3f0..cbf2d4bddddb83ac106d79ca4216aa8f114efffd 100644 (file)
@@ -18524,6 +18524,28 @@ enum ggml_opt_result ggml_opt_resume_g(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+void ggml_quantize_init(enum ggml_type type) {
+    ggml_critical_section_start();
+
+    switch (type) {
+        case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
+        case GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+        default: // nothing
+            break;
+    }
+
+    ggml_critical_section_end();
+}
+
+void ggml_quantize_free(void) {
+    ggml_critical_section_start();
+
+    iq2xs_free_impl(256);
+    iq2xs_free_impl(512);
+
+    ggml_critical_section_end();
+}
+
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK4_0 == 0);
     const int nb = k / QK4_0;
@@ -18651,9 +18673,15 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK8_0*sizeof(block_q8_0));
 }
 
+bool ggml_quantize_requires_imatrix(enum ggml_type type) {
+    return
+        type == GGML_TYPE_IQ2_XXS ||
+        type == GGML_TYPE_IQ2_XS;
+}
+
 size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
         int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
-    (void)imatrix;
+    ggml_quantize_init(type); // this is noop if already initialized
     size_t result = 0;
     int n = nrows * n_per_row;
     switch (type) {
@@ -18766,13 +18794,13 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
             } break;
         case GGML_TYPE_F16:
             {
-                int elemsize = sizeof(ggml_fp16_t);
+                size_t elemsize = sizeof(ggml_fp16_t);
                 ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
                 result = n * elemsize;
             } break;
         case GGML_TYPE_F32:
             {
-                int elemsize = sizeof(float);
+                size_t elemsize = sizeof(float);
                 result = n * elemsize;
                 memcpy((uint8_t *)dst + start * elemsize, src + start, result);
             } break;
index 22a7856d46f417f91cee1d8602a475a7a1650c39..55ce14e0d902c50c613d98d6b6cc0ce211119019 100644 (file)
 #include <vector>
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
+    // static RNG initialization (revisit if n_threads stops being constant)
+    static const size_t n_threads = std::thread::hardware_concurrency();
+    static std::vector<std::default_random_engine> generators = []() {
+        std::random_device rd;
+        std::vector<std::default_random_engine> vec;
+        vec.reserve(n_threads);
+        //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
+        for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
+        return vec;
+    }();
+
     size_t size = ggml_nelements(tensor);
     std::vector<float> data(size);
 
-#if 0
-    static std::default_random_engine generator(1234);
-    std::uniform_real_distribution<float> distribution(min, max);
-
-    for (size_t i = 0; i < size; i++) {
-        data[i] = distribution(generator);
-    }
-#else
-    auto init_thread = [&](size_t start, size_t end) {
-        std::random_device rd;
-        std::default_random_engine generator(rd());
+    auto init_thread = [&](size_t ith, size_t start, size_t end) {
         std::uniform_real_distribution<float> distribution(min, max);
-
         for (size_t i = start; i < end; i++) {
-            data[i] = distribution(generator);
+            data[i] = distribution(generators[ith]);
         }
     };
 
-    size_t n_threads = std::thread::hardware_concurrency();
     std::vector<std::thread> threads;
     threads.reserve(n_threads);
     for (size_t i = 0; i < n_threads; i++) {
         size_t start =     i*size/n_threads;
         size_t end   = (i+1)*size/n_threads;
-        threads.emplace_back(init_thread, start, end);
+        threads.emplace_back(init_thread, i, start, end);
     }
     for (auto & t : threads) {
         t.join();
     }
-#endif
 
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
@@ -56,7 +54,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
         std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
         int64_t hist[16];
-        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, nullptr);
+        std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
+        const float * im = imatrix.data();
+        if (!ggml_quantize_requires_imatrix(tensor->type)) {
+            // when the imatrix is optional, we want to test both quantization with and without imatrix
+            // use one of the random numbers to decide
+            if (data[0] > 0.5f*(min + max)) {
+                im = nullptr;
+            }
+        }
+        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, im);
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -1472,7 +1479,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q8_0,
         GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
         GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K
+        GGML_TYPE_Q6_K,
+        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
     };
 
     // unary ops
@@ -1752,6 +1760,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    ggml_quantize_free();
+
     printf("\033[1;32mOK\033[0m\n");
     return 0;
 }