]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : handle saving/loading null layers in recurrent memory (#14675)
authorl3utterfly <redacted>
Wed, 23 Jul 2025 08:16:41 +0000 (16:16 +0800)
committerGitHub <redacted>
Wed, 23 Jul 2025 08:16:41 +0000 (11:16 +0300)
* Update llama-memory-recurrent.cpp

handle saving/loading null layers in recurrent memory

* fixed styling issues and updated comments

* fix styling issue

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Sigbjørn Skjæret <redacted>
src/llama-memory-recurrent.cpp

index 1e1a7a9b31e4638bcf027d862c70a7154a0da611..c0c2ec084dc1447787e9c3e95aa64f572244eb0c 100644 (file)
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+        if (r_l[il] == nullptr) continue;
 
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
 
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if (r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Read type of value