]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : synchronize before get/set session data (#6911)
authorslaren <redacted>
Thu, 25 Apr 2024 15:59:03 +0000 (17:59 +0200)
committerGitHub <redacted>
Thu, 25 Apr 2024 15:59:03 +0000 (17:59 +0200)
llama.cpp

index a43cde109ef3143d82ba536c7fb7cf03937558eb..0c01a353999b3ca8314bcbf4f371c523f6f8d4b1 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -16131,6 +16131,8 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
+    llama_synchronize(ctx);
+
     // copy rng
     {
         std::ostringstream rng_ss;
@@ -16283,6 +16285,8 @@ size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
 
 // Sets the state reading from the specified source address
 size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
+    llama_synchronize(ctx);
+
     const uint8_t * inp = src;
 
     // set rng
@@ -16587,6 +16591,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
 }
 
 static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
     const auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented
 
@@ -16704,6 +16710,8 @@ size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_s
 }
 
 size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
+    llama_synchronize(ctx);
+
     auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented