]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : loading models directly into VRAM, norm calculation on GPU, broadcasting for...
authorJohannes Gäßler <redacted>
Sat, 20 May 2023 12:19:28 +0000 (14:19 +0200)
committerGitHub <redacted>
Sat, 20 May 2023 12:19:28 +0000 (15:19 +0300)
* Broadcasting for ggml_mul

* CUDA kernel for ggml_mul, norms in VRAM

* GPU weights not in RAM, direct loading with cuFile

* fixup! GPU weights not in RAM, direct loading with cuFile

* fixup! GPU weights not in RAM, direct loading with cuFile

* define default model path once, sync path with readme (#1366)

* ~7% faster Q5_1 AVX2 code (#1477)

* convert.py: Support models which are stored in a single pytorch_model.bin (#1469)

* Support models in a single pytorch_model.bin

* Remove spurious line with typo

* benchmark-matmul: Print the average of the test results (#1490)

* Remove unused n_parts parameter (#1509)

* Fixes #1511 lambda issue for w64devkit (mingw) (#1513)

* Fix for w64devkit and mingw

* make kv_f16 the default for api users (#1517)

* minor : fix compile warnings

* readme : adds WizardLM to the list of supported models (#1485)

* main : make reverse prompt option act as a stop token in non-interactive mode (#1032)

* Make reverse prompt option act as a stop token in non-interactive scenarios

* Making requested review changes

* Update gpt_params_parse and fix a merge error

* Revert "Update gpt_params_parse and fix a merge error"

This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8.

* Update gpt_params_parse and fix a merge error take 2

* examples : add persistent chat (#1495)

* examples : add persistent chat

* examples : fix whitespace

---------

Co-authored-by: Georgi Gerganov <redacted>
* tests : add missing header

* ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)

* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0

* llama : bump LLAMA_FILE_VERSION to 3

* cuda : update Q4 and Q8 dequantize kernels

* ggml : fix AVX dot products

* readme : update performance table + hot topics

* ggml : fix scalar implementation of Q4_1 dot

* llama : fix compile warnings in llama_set_state_data()

* llama : fix name shadowing and C4146 (#1526)

* Fix name shadowing and C4146

* Fix if macros not using defined when required

* Update llama-util.h

Co-authored-by: github-actions[bot] <redacted>
* Update llama-util.h

Co-authored-by: github-actions[bot] <redacted>
* Code style

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: github-actions[bot] <redacted>
Co-authored-by: Georgi Gerganov <redacted>
* Fix for mingw (#1462)

* llama : add llama_init_backend() API (close #1527)

* feature : add blis and other BLAS implementation support (#1502)

* feature: add blis support

* feature: allow all BLA_VENDOR to be assigned in cmake arguments. align with whisper.cpp pr 927

* fix: version detection for BLA_SIZEOF_INTEGER, recover min version of cmake

* Fix typo in INTEGER

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
* Revert "feature : add blis and other BLAS implementation support (#1502)"

This reverts commit 07e9ace0f9da424d82e75df969642522880feb92.

* GPU weights not in RAM, direct loading with cuFile

* llama : code style fixes + progress print fix

* ggml : ggml_mul better broadcast support

* cmake : workarounds for cufile when CMake version < 3.25

* gg rebase fixup

* Loop in llama.cpp, fixed progress callback

* Attempt clang-tidy fix

* llama : fix vram size computation

* Add forgotten fclose()

---------

Co-authored-by: András Salamon <redacted>
Co-authored-by: Ilya Kurdyukov <redacted>
Co-authored-by: Tom Jobbins <redacted>
Co-authored-by: rankaiyx <redacted>
Co-authored-by: Stephan Walter <redacted>
Co-authored-by: DannyDaemonic <redacted>
Co-authored-by: Erik Scholz <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: David Kennedy <redacted>
Co-authored-by: Jason McCartney <redacted>
Co-authored-by: Evan Jones <redacted>
Co-authored-by: Maxime <redacted>
Co-authored-by: github-actions[bot] <redacted>
Co-authored-by: Zenix <redacted>
ggml-cuda.cu
ggml-cuda.h
ggml.c
llama-util.h
llama.cpp

index 688bcf79928d520ca37a4fa71a3c5f8a9b08dd30..35d2e457cbf8487fd967b6459925a85d1ab0282b 100644 (file)
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
 
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
+
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -797,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}
index 4e2c24283ccf4b29d180d3f3f0625183f128fb88..6a04dde6c37a9d3179747575ecb2dee455a29720 100644 (file)
@@ -6,6 +6,7 @@ extern "C" {
 
 void   ggml_init_cublas(void);
 
+void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef  __cplusplus
 }
diff --git a/ggml.c b/ggml.c
index 939ab4d625ef6b9bbfa7664d698bc3c2bb235de0..d86e5942ab46e1bce7fa77d3130aefd3d7d4f65f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
         (t1->ne[3]%t0->ne[3] == 0);
 }
 
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
 static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
@@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -7960,7 +7970,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7968,10 +7978,25 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
@@ -7990,44 +8015,51 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
 
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
 #else
-            ggml_vec_mul_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
index a79c5dadd404aa81948c0b0de3aa282dce5dbb07..3cac9f681800bcd83962a603ca04bf79f69f4b9f 100644 (file)
@@ -172,7 +172,7 @@ struct llama_mmap {
 #ifdef _POSIX_MAPPED_FILES
     static constexpr bool SUPPORTED = true;
 
-    llama_mmap(struct llama_file * file, bool prefetch = true) {
+    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
         size = file->size;
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
@@ -184,9 +184,9 @@ struct llama_mmap {
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
         }
 
-        if (prefetch) {
+        if (prefetch > 0) {
             // Advise the kernel to preload the mapped memory
-            if (madvise(addr, file->size, MADV_WILLNEED)) {
+            if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) {
                 fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
                         strerror(errno));
             }
index 5e6980b482a6efa68420c7548607f204217b67d0..b38d55dda0e27915f412cfe1a55475f5c3b1000e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1,6 +1,7 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
+#include <cstddef>
 #include <cstdint>
 #include <cstdio>
 #endif
@@ -645,7 +646,7 @@ struct llama_model_loader {
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -656,10 +657,10 @@ struct llama_model_loader {
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
 
-        return get_tensor_for(lt);
+        return get_tensor_for(lt, backend);
     }
 
-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -669,6 +670,7 @@ struct llama_model_loader {
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+        tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
         return tensor;
@@ -682,12 +684,16 @@ struct llama_model_loader {
 
     void load_all_data(llama_progress_callback progress_callback, void *  progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
+        size_t prefetch_size = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             data_size += lt.size;
+            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
+                prefetch_size += lt.size;
+            }
         }
 
         if (use_mmap) {
-            mapping.reset(new llama_mmap(&file_loaders.at(0)->file));
+            mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size));
             if (!lmlock) {
                 // Don't call the callback since the actual loading will be lazy
                 // and we can't measure it.
@@ -700,6 +706,9 @@ struct llama_model_loader {
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
+                continue;
+            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
@@ -712,9 +721,6 @@ struct llama_model_loader {
                 lmlock->grow_to(done_size);
             }
         }
-        if (progress_callback) {
-            progress_callback(1.0f, progress_callback_user_data);
-        }
     }
 
     void load_data_for(llama_load_tensor & lt) {
@@ -969,27 +975,7 @@ static void llama_model_load_internal(
     size_t ctx_size;
     size_t mmapped_size;
     ml->calc_sizes(&ctx_size, &mmapped_size);
-    fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/1024.0/1024.0);
-
-    // print memory requirements
-    {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            ctx_size +
-            mmapped_size +
-            MEM_REQ_SCRATCH0().at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-
-        // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-    }
+    fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);
 
     // create the ggml context
     {
@@ -1011,7 +997,14 @@ static void llama_model_load_internal(
         }
     }
 
+#ifdef GGML_USE_CUBLAS
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
+#else
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
+#endif
+
     // prepare memory for the weights
+    size_t vram_total = 0;
     {
         const uint32_t n_embd  = hparams.n_embd;
         const uint32_t n_layer = hparams.n_layer;
@@ -1019,70 +1012,122 @@ static void llama_model_load_internal(
 
         ml->ggml_ctx = ctx;
 
-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab});
-        model.norm           = ml->get_tensor("norm.weight",           {n_embd});
-        model.output         = ml->get_tensor("output.weight",         {n_embd, n_vocab});
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          GGML_BACKEND_CPU);
+
+        // "output" tensor
+        {
+            ggml_backend backend_output;
+            if (n_gpu_layers > int(n_layer)) { // NOLINT
+                backend_output = LLAMA_BACKEND_OFFLOAD;
+            } else {
+                backend_output = GGML_BACKEND_CPU;
+            }
+
+            model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
+        }
+
+        const int i_gpu_start = n_layer - n_gpu_layers;
 
         model.layers.resize(n_layer);
         for (uint32_t i = 0; i < n_layer; ++i) {
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+
             auto & layer = model.layers[i];
 
             std::string layers_i = "layers." + std::to_string(i);
 
-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd});
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd});
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd});
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd});
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd});
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
 
-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd});
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd});
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff},   backend);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd}, backend);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff},   backend);
+
+            if (backend == GGML_BACKEND_CUDA) {
+                vram_total +=
+                    ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)             +
+                    ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
+                    ggml_nbytes(layer.w1)             + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     }
 
     ml->done_getting_tensors();
 
-    // populate `tensors_by_name`
-    for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-        model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
-    }
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
 
-    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            mmapped_size - vram_total + // weights in VRAM not in memory
+            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH1().at(model.type) +
+            MEM_REQ_EVAL().at(model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF().at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
 
-    model.mapping = std::move(ml->mapping);
 #ifdef GGML_USE_CUBLAS
-    {
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
 
         fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+        }
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+#else
+        (void) n_gpu_layers;
+#endif
+    }
 
-        size_t vram_total = 0;
+    // populate `tensors_by_name`
+    for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+        model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
+    }
 
-        for (int i = 0; i < n_gpu; ++i) {
-            const auto & layer = model.layers[i];
+    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
-            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
-            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
-            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
-            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
-            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
-            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
-            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
+#ifdef GGML_USE_CUBLAS
+    {
+        size_t done_size = 0;
+        size_t data_size = 0;
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            data_size += lt.size;
+            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
+                done_size += lt.size;
+            }
         }
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
-            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                continue;
+            }
+            if (progress_callback) {
+                progress_callback((float) done_size / data_size, progress_callback_user_data);
+            }
+            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
+            done_size += lt.size;
         }
+    }
+#endif // GGML_USE_CUBLAS
 
-        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+    if (progress_callback) {
+        progress_callback(1.0f, progress_callback_user_data);
     }
-#else
-    (void) n_gpu_layers;
-#endif
+
+    model.mapping = std::move(ml->mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1181,10 +1226,8 @@ static bool llama_eval_internal(
         {
             cur = ggml_rms_norm(ctx0, inpL);
 
-            // cur = attention_norm*cur
-            cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
-                        cur);
+            // cur = cur*attention_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
         // self-attention
@@ -1291,10 +1334,8 @@ static bool llama_eval_internal(
             {
                 cur = ggml_rms_norm(ctx0, inpFF);
 
-                // cur = ffn_norm*cur
-                cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
-                        cur);
+                // cur = cur*ffn_norm(broadcasted)
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
             }
 
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
@@ -1331,10 +1372,8 @@ static bool llama_eval_internal(
 
         inpL = ggml_rms_norm(ctx0, inpL);
 
-        // inpL = norm*inpL
-        inpL = ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.norm, inpL),
-                    inpL);
+        // inpL = inpL*norm(broadcasted)
+        inpL = ggml_mul(ctx0, inpL, model.norm);
 
         embeddings = inpL;
     }
@@ -2158,7 +2197,7 @@ struct llama_context * llama_init_from_file(
             unsigned * cur_percentage_p = (unsigned *) ctx;
             unsigned percentage = (unsigned) (100 * progress);
             while (percentage > *cur_percentage_p) {
-                ++*cur_percentage_p;
+                *cur_percentage_p = percentage;
                 fprintf(stderr, ".");
                 fflush(stderr);
                 if (percentage >= 100) {
@@ -2315,7 +2354,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
         // maybe this should in llama_model_loader
         if (model_loader->use_mmap) {
-            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
+            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ 0));
         }
     }
 
@@ -2408,7 +2447,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
                 }
                 size_t idx = model_loader->tensors_map.name_to_idx[base_name];
                 llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
                 lt.data = (uint8_t *) lt.ggml_tensor->data;
                 model_loader->load_data_for(lt);
                 lt.ggml_tensor->data = lt.data;