]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
add basic tensor data validation function (#6884)
authorslaren <redacted>
Fri, 26 Apr 2024 16:39:58 +0000 (18:39 +0200)
committerGitHub <redacted>
Fri, 26 Apr 2024 16:39:58 +0000 (18:39 +0200)
* add basic tensor data validation function

* add --check-tensors command line argument

tensor validation is disabled by default and can be enabled by adding
`--check-tensors` to the command line arguments.

quantize always validates tensors.

common/common.cpp
common/common.h
ggml-quants.c
ggml.h
llama.cpp
llama.h

index 97f55b053eee116494b6f2e359b9f28de5170c25..9ab8aef7ec832460640f9f9e7154867a8c4e831c 100644 (file)
@@ -1089,6 +1089,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_print = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--check-tensors") {
+        params.check_tensors = true;
+        return true;
+    }
     if (arg == "--ppl-output-type") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1554,6 +1558,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -ptc N, --print-token-count N\n");
     printf("                        print token count every N tokens (default: %d)\n", params.n_print);
+    printf("  --check-tensors       check model tensor data for invalid values\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
@@ -1774,6 +1779,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.tensor_split    = params.tensor_split;
     mparams.use_mmap        = params.use_mmap;
     mparams.use_mlock       = params.use_mlock;
+    mparams.check_tensors   = params.check_tensors;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
index 87361e8e915008489454797a92ce5f164b844dae..0005f143ef44d20f7ef52ef1bf998c8520bd450c 100644 (file)
@@ -161,6 +161,7 @@ struct gpt_params {
     bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
     bool no_kv_offload     = false; // disable KV offloading
     bool warmup            = true;  // warmup run
+    bool check_tensors     = false; // validate tensor data
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
index c147531df1cd50c6e2578da042467685b0013e76..444d1e55ebd5404f70ce62f32f5ad08bbbdebd2d 100644 (file)
@@ -12383,3 +12383,287 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
     block_iq2_s * restrict y = vy;
     quantize_row_iq2_s_reference(x, y, k);
 }
+
+static bool validate_float(float f, size_t i) {
+    if (isinf(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+static bool isinf_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static bool validate_fp16(ggml_fp16_t f, size_t i) {
+    if (isinf_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
+            return false; \
+        } \
+    }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+    if (type < 0 || type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+        return false;
+    }
+
+    if (nbytes % ggml_type_size(type) != 0) {
+        fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
+        return false;
+    }
+
+    const size_t nb = nbytes/ggml_type_size(type);
+
+    switch (type) {
+        case GGML_TYPE_F16:
+            {
+                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 15 < nb; i += 16) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 16; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 7 < nb; i += 8) {
+                    uint16x8_t v = vld1q_u16(f + i);
+                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
+                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_fp16(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                const float * f = (const float *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 7 < nb; i += 8) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 3 < nb; i += 4) {
+                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
+                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
+                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 4; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F64:
+            {
+                const double * f = (const double *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+            } break;
+        case GGML_TYPE_Q2_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+            } break;
+        case GGML_TYPE_Q8_K:
+            {
+                const block_q8_K * q = (const block_q8_K *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(q[i].d, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_M:
+            {
+                const block_iq1_m * q = (const block_iq1_m *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                #if QK_K == 64
+                    if (!validate_fp16(q[i].d, i)) {
+                        return false;
+                    }
+                #else
+                    iq1m_scale_t scale;
+                    const uint16_t * sc = (const uint16_t *)q[i].scales;
+                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                    if (!validate_fp16(scale.f16, i)) {
+                        return false;
+                    }
+                #endif
+                }
+            } break;
+        case GGML_TYPE_IQ2_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+            } break;
+
+        case GGML_TYPE_IQ3_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_XS:
+        #if QK_K != 64
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+            } break;
+        #endif
+        // with QK_K == 64, iq4_xs is iq4_nl
+        case GGML_TYPE_IQ4_NL:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+            // nothing to validate
+            break;
+        default:
+            {
+                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+                return false;
+            }
+    }
+
+    return true;
+}
diff --git a/ggml.h b/ggml.h
index 4d1d77fe9330cdb2e31b1231db6e80dea3704dee..86e5a8dc57ad198237b275efb2213b7b1c085f82 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -762,6 +762,8 @@ extern "C" {
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
 
+    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
index d728bd499d38f7084bc5dc1b27df76880ed1e8a5..e5e640016d9a8f0cbae610f8963eddc088443380 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -75,6 +75,7 @@
 #include <forward_list>
 #include <fstream>
 #include <functional>
+#include <future>
 #include <initializer_list>
 #include <locale>
 #include <map>
@@ -2985,6 +2986,7 @@ struct llama_model_loader {
     size_t  n_bytes    = 0;
 
     bool use_mmap = false;
+    bool check_tensors;
 
     llama_files files;
     llama_ftype ftype;
@@ -3018,7 +3020,7 @@ struct llama_model_loader {
     std::string arch_name;
     LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
 
-    llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) {
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
         int trace = 0;
         if (getenv("LLAMA_TRACE")) {
             trace = atoi(getenv("LLAMA_TRACE"));
@@ -3223,6 +3225,7 @@ struct llama_model_loader {
         }
 
         this->use_mmap = use_mmap;
+        this->check_tensors = check_tensors;
     }
 
     ~llama_model_loader() {
@@ -3481,6 +3484,10 @@ struct llama_model_loader {
             file->seek(w.offs, SEEK_SET);
             file->read_raw(cur->data, ggml_nbytes(cur));
         }
+
+        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
+            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+        }
     }
 
     size_t size_done = 0;
@@ -3497,6 +3504,8 @@ struct llama_model_loader {
         GGML_ASSERT(size_data != 0 && "call init_mappings() first");
 
         std::vector<no_init<uint8_t>> read_buf;
+        std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
+
         for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
             const auto * weight = get_weight(ggml_get_name(cur));
             if (weight == nullptr) {
@@ -3518,37 +3527,66 @@ struct llama_model_loader {
                 if (bufs_mmap.count(weight->idx)) {
                     buf_mmap = bufs_mmap.at(weight->idx);
                 }
+                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
+
+                if (check_tensors) {
+                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
+                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
+                    }));
+                }
+
                 GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
                 if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs);
+                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
                     if (lmlocks) {
                         const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + ggml_nbytes(cur));
+                        lmlock->grow_to(weight->offs + n_size);
                     }
 
                     auto & mmap_used = mmaps_used[weight->idx];
                     mmap_used.first  = std::min(mmap_used.first,  weight->offs);
                     mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
                 } else {
-                    ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size);
+                    ggml_backend_tensor_set(cur, data, 0, n_size);
                 }
             } else {
                 GGML_ASSERT(weight->idx < files.size());
                 const auto & file = files.at(weight->idx);
                 if (ggml_backend_buffer_is_host(cur->buffer)) {
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, ggml_nbytes(cur));
+                    file->read_raw(cur->data, n_size);
+                    if (check_tensors) {
+                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
+                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
+                        }));
+                    }
                 } else {
-                    read_buf.resize(ggml_nbytes(cur));
+                    read_buf.resize(n_size);
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(read_buf.data(), ggml_nbytes(cur));
+                    file->read_raw(read_buf.data(), n_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                    }
                 }
             }
 
             size_done += n_size;
         }
 
+        // check validation results
+        bool validation_failed = false;
+        for (auto & future : validation_result) {
+            auto result = future.get();
+            if (!result.second) {
+                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
+                validation_failed = true;
+            }
+        }
+        if (validation_failed) {
+            throw std::runtime_error("found tensors with invalid data");
+        }
+
         // check if this is the last call and do final cleanup
         if (size_done >= size_data) {
             // unmap offloaded tensors and metadata
@@ -5975,7 +6013,7 @@ static bool llm_load_tensors(
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
+        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
         model.hparams.vocab_only = params.vocab_only;
 
@@ -14360,14 +14398,20 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
 }
 
 static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
-    std::mutex mutex;
-    int64_t counter = 0;
-    size_t new_size = 0;
     if (nthread < 2) {
         // single-thread
-        return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
+        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
+        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
+            throw std::runtime_error("quantized data validation failed");
+        }
+        return new_size;
     }
-    auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size,
+
+    std::mutex mutex;
+    int64_t counter = 0;
+    size_t new_size = 0;
+    bool valid = true;
+    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
             nrows, n_per_row, imatrix]() {
         const int64_t nrows_per_chunk = chunk_size / n_per_row;
         size_t local_size = 0;
@@ -14382,7 +14426,17 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
             }
             lock.unlock();
             const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
-            local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
+            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
+            local_size += this_size;
+
+            // validate the quantized data
+            const size_t row_size  = ggml_row_size(new_type, n_per_row);
+            void * this_data = (char *) new_data + first_row * row_size;
+            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
+                std::unique_lock<std::mutex> lock(mutex);
+                valid = false;
+                break;
+            }
         }
     };
     for (int it = 0; it < nthread - 1; ++it) {
@@ -14391,6 +14445,9 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
     compute();
     for (auto & w : workers) { w.join(); }
     workers.clear();
+    if (!valid) {
+        throw std::runtime_error("quantized data validation failed");
+    }
     return new_size;
 }
 
@@ -14453,7 +14510,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, kv_overrides);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
     llama_model model;
@@ -14814,7 +14871,7 @@ static int llama_apply_lora_from_file_internal(
     std::unique_ptr<llama_model_loader> ml;
     if (path_base_model) {
         LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr));
+        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
         ml->init_mappings(/*prefetch*/ false); // no prefetching
     }
 
@@ -15073,6 +15130,7 @@ struct llama_model_params llama_model_default_params() {
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
     };
 
 #ifdef GGML_USE_METAL
diff --git a/llama.h b/llama.h
index 8aa763672be1b864536dbc122f6729eae01dc92a..0a79fa765566aa6d397be2d66a96c3747f8fa0ac 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -232,9 +232,10 @@ extern "C" {
         const struct llama_model_kv_override * kv_overrides;
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool vocab_only; // only load the vocabulary, no weights
-        bool use_mmap;   // use mmap if possible
-        bool use_mlock;  // force system to keep model in RAM
+        bool vocab_only;    // only load the vocabulary, no weights
+        bool use_mmap;      // use mmap if possible
+        bool use_mlock;     // force system to keep model in RAM
+        bool check_tensors; // validate model tensor data
     };
 
     struct llama_context_params {