]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Leverage mmap for offloading tensors to GPU (#1597)
authorHoward Su <redacted>
Mon, 12 Jun 2023 12:44:16 +0000 (20:44 +0800)
committerGitHub <redacted>
Mon, 12 Jun 2023 12:44:16 +0000 (14:44 +0200)
* Rebase to latest

* Show progress

* Add assert to make sure we only allocate temp buffer for non-CPU backend tensor

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml-cuda.cu
ggml-cuda.h
ggml-opencl.cpp
ggml-opencl.h
llama.cpp

index 4f2195f77e9843266d1e661e66866c97745ae233..3b9a5ddfb0d8f8a04399307a67aa4b3e27cabea0 100644 (file)
@@ -1713,8 +1713,7 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }
 
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    FILE * fp = fopen(fname, "rb");
+void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     int nrows = ggml_nrows(tensor);
     const size_t nb1 = tensor->nb[1];
     ggml_backend backend = tensor->backend;
@@ -1748,35 +1747,19 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
 
         int64_t nrows_split = row_high - row_low;
 
-        const size_t offset_split = offset + row_low*nb1;
+        const size_t offset_split = row_low*nb1;
         const size_t size = ggml_nbytes_split(tensor, nrows_split);
 
         void * buf;
         CUDA_CHECK(cudaMalloc(&buf, size));
-        void * buf_host = malloc(size);
-
-#ifdef _WIN32
-        int ret = _fseeki64(fp, (__int64) offset_split, SEEK_SET);
-#else
-        int ret = fseek(fp, (long) offset_split, SEEK_SET);
-#endif
-        GGML_ASSERT(ret == 0); // same
-
-        size_t ret2 = fread(buf_host, size, 1, fp);
-        if (ret2 != 1) {
-            fprintf(stderr, "unexpectedly reached end of file");
-            exit(1);
-        }
+        void * buf_host = (char*)data + offset_split;
 
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
-        cudaDeviceSynchronize();
 
-        free(buf_host);
         extra->data_device[id] = buf;
     }
 
     tensor->extra = extra;
-    fclose(fp);
 }
 
 void ggml_cuda_free_data(struct ggml_tensor * tensor) {
index 3b74e32e2592720d6853ed8e99493982337628ad..fde6d4085bf29001d00529f4c1e250c00a46f640 100644 (file)
@@ -24,7 +24,8 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void   ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
+void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+
 void   ggml_cuda_free_data(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
index 7b6daf4a87e8571bd00cb774ec45f7f7d8a84eab..5df922abd720e062bd27297a066c088df51df072 100644 (file)
@@ -1167,7 +1167,7 @@ size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct g
     return 0;
 }
 
-void ggml_cl_transform_tensor(ggml_tensor * tensor) {
+void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
     const int64_t ne0 = tensor->ne[0];
     const int64_t ne1 = tensor->ne[1];
     const int64_t ne2 = tensor->ne[2];
@@ -1179,6 +1179,7 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
     size_t q_size;
     cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
 
+    tensor->data = data;
     // copy tensor to device
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
@@ -1190,35 +1191,5 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
     CL_CHECK(clFinish(queue));
 
     tensor->data = dst;
-    tensor->backend = GGML_BACKEND_GPU;
-}
-
-void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    cl_int err;
-    FILE * fp = fopen(fname, "rb");
-
-    const size_t size = ggml_nbytes(tensor);
-
-    cl_mem dst;
-    CL_CHECK((dst = clCreateBuffer(context, CL_MEM_READ_ONLY, size, nullptr, &err), err));
-    void * buf_host = malloc(size);
-
-#ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
-#else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
-#endif
-    GGML_ASSERT(ret == 0); // same
-
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
-    }
-
-    clEnqueueWriteBuffer(queue, dst, CL_TRUE, 0, size, buf_host, 0, nullptr, nullptr);
-
-    tensor->data = dst;
-    free(buf_host);
-    fclose(fp);
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 }
index bf95e5cd0b9de860e9fb4a0d5ea3eecd412754d8..a92b445c9d7660da6ed5dfcbc4cce9ae7a5b9827 100644 (file)
@@ -18,8 +18,7 @@ void   ggml_cl_host_free(void * ptr);
 
 void ggml_cl_free_data(const struct ggml_tensor* tensor);
 
-void ggml_cl_transform_tensor(struct ggml_tensor * tensor);
-void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, size_t offset);
+void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
index e100e2bc98bddf3681623dcabd1965a4404881c8..a9a7794ae56605158a27fb0415d0b94bf3632810 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -707,6 +707,9 @@ struct llama_model_loader {
 
     struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
+        if (backend != GGML_BACKEND_CPU) {
+            ggml_set_no_alloc(ggml_ctx, true);
+        }
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
         } else {
@@ -716,6 +719,9 @@ struct llama_model_loader {
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
 
+        if (backend != GGML_BACKEND_CPU) {
+            ggml_set_no_alloc(ggml_ctx, use_mmap);
+        }
         tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
@@ -731,6 +737,7 @@ struct llama_model_loader {
     void load_all_data(llama_progress_callback progress_callback, void *  progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
         size_t prefetch_size = 0;
+        size_t lock_size = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             data_size += lt.size;
             if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
@@ -740,11 +747,6 @@ struct llama_model_loader {
 
         if (use_mmap) {
             mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size));
-            if (!lmlock) {
-                // Don't call the callback since the actual loading will be lazy
-                // and we can't measure it.
-                progress_callback = NULL;
-            }
             if (lmlock) {
                 lmlock->init(mapping->addr);
             }
@@ -752,20 +754,49 @@ struct llama_model_loader {
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
-                continue;
-            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
             LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already
             lt.data = (uint8_t *) lt.ggml_tensor->data;
+
+            // allocate temp buffer if not using mmap
+            if (!use_mmap && lt.data == NULL) {
+                GGML_ASSERT(lt.ggml_tensor->backend != GGML_BACKEND_CPU);
+                lt.data = (uint8_t*)malloc(ggml_nbytes(lt.ggml_tensor));
+            }
+
             load_data_for(lt);
-            lt.ggml_tensor->data = lt.data;
-            done_size += lt.size;
-            if (use_mmap && lmlock) {
-                lmlock->grow_to(done_size);
+
+            switch(lt.ggml_tensor->backend) {
+                case GGML_BACKEND_CPU:
+                    lt.ggml_tensor->data = lt.data;
+                    if (use_mmap && lmlock) {
+                        lock_size += lt.size;
+                        lmlock->grow_to(lock_size);
+                    }
+                    break;
+#if defined(GGML_USE_CUBLAS)
+                case GGML_BACKEND_GPU:
+                case GGML_BACKEND_GPU_SPLIT:
+                    ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor);
+                    if (!use_mmap) {
+                        free(lt.data);
+                    }
+                    break;
+#elif defined(GGML_USE_CLBLAST)
+                case GGML_BACKEND_GPU:
+                    ggml_cl_transform_tensor(lt.data, lt.ggml_tensor);
+                    if (!use_mmap) {
+                        free(lt.data);
+                    }
+                    break;
+#endif
+                default:
+                    continue;
             }
+
+            done_size += lt.size;
         }
     }
 
@@ -1141,7 +1172,7 @@ static void llama_model_load_internal(
             if (backend == GGML_BACKEND_GPU) {
                 vram_weights +=
                     ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)             +
-                    ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
+                    ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
                     ggml_nbytes(layer.w1)             + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
             }
         }
@@ -1196,58 +1227,14 @@ static void llama_model_load_internal(
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
     }
 
-    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
-
 #if defined(GGML_USE_CUBLAS)
     {
         ggml_cuda_set_tensor_split(tensor_split);
-
-        size_t done_size = 0;
-        size_t data_size = 0;
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            data_size += lt.size;
-            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
-                done_size += lt.size;
-            }
-        }
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            ggml_backend backend = lt.ggml_tensor->backend;
-            if (backend != GGML_BACKEND_GPU && backend != GGML_BACKEND_GPU_SPLIT) {
-                continue;
-            }
-            if (progress_callback) {
-                progress_callback((float) done_size / data_size, progress_callback_user_data);
-            }
-            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
-            done_size += lt.size;
-        }
-    }
-#elif defined(GGML_USE_CLBLAST)
-    {
-        size_t done_size = 0;
-        size_t data_size = 0;
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            data_size += lt.size;
-            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
-                done_size += lt.size;
-            }
-        }
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_GPU) {
-                continue;
-            }
-            if (progress_callback) {
-                progress_callback((float) done_size / data_size, progress_callback_user_data);
-            }
-            ggml_cl_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
-            done_size += lt.size;
-        }
     }
-#else
-    (void) n_batch;
-    (void) tensor_split;
 #endif
 
+    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
+
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);
     }