]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama2c : fix segfault and alloc-dealloc-mismatch (#2913)
authorCebtenzzre <redacted>
Fri, 1 Sep 2023 09:03:49 +0000 (05:03 -0400)
committerGitHub <redacted>
Fri, 1 Sep 2023 09:03:49 +0000 (12:03 +0300)
* llama2c : fix segfault if vocab is not found

* llama2c : fix mismatch between new[] and delete

* llama2c : fix basename on Windows

* llama2c : use a destructor to prevent memory leaks

examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

index e9e070b1fa32125c9229832bd80330b4a1b33910..0b03c9d2b46186db8cd96432c3d4bdf4b29f169e 100644 (file)
@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-typedef struct {
+struct TransformerWeights {
     // token embedding table
     float* token_embedding_table;    // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,7 +97,22 @@ typedef struct {
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-} TransformerWeights;
+
+    ~TransformerWeights() {
+        delete[] token_embedding_table;
+        delete[] rms_att_weight;
+        delete[] rms_ffn_weight;
+        delete[] wq;
+        delete[] wk;
+        delete[] wv;
+        delete[] wo;
+        delete[] w1;
+        delete[] w2;
+        delete[] w3;
+        delete[] rms_final_weight;
+        delete[] wcls;
+    }
+};
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
@@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar
     return 0;
 }
 
-void free_weights(TransformerWeights* w) {
-    delete w->token_embedding_table;
-    delete w->rms_att_weight;
-    delete w->rms_ffn_weight;
-    delete w->wq;
-    delete w->wk;
-    delete w->wv;
-    delete w->wo;
-    delete w->w1;
-    delete w->w2;
-    delete w->w3;
-    delete w->rms_final_weight;
-    if (w->wcls) delete w->wcls;
-}
-
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
         // assume llama2.c vocabulary
         printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
         llama_file file(filename, "rb");
+        if (!file.fp) {
+            fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
+            exit(1);
+        }
         const int  n_vocab = config->vocab_size;
         /* uint32_t max_token_length =  */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
@@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/");
+    size_t pos = path.find_last_of("/\\");
     if (pos == std::string::npos) {
         return path;
     }
@@ -911,7 +915,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights;
+    TransformerWeights weights = {};
     {
         FILE *file = fopen(params.fn_llama2c_model, "rb");
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
@@ -953,6 +957,5 @@ int main(int argc, char ** argv) {
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
-    free_weights(&weights);
     return 0;
 }