]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (CUDA ReLU, CPU-only with CUDA, bloom fix, etc) (#607)
authorGeorgi Gerganov <redacted>
Mon, 13 Nov 2023 14:54:34 +0000 (16:54 +0200)
committerGitHub <redacted>
Mon, 13 Nov 2023 14:54:34 +0000 (16:54 +0200)
ggml-ci

include/ggml/ggml.h
scripts/sync-llama.sh
src/ggml-alloc.c
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml.c

index 52ae6755a89ad2e8746642233fea80a1d149937a..8e6b646066b7a488197becae814d17e504916194 100644 (file)
@@ -1371,8 +1371,13 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
             float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow,
             float                 xpos_base,
             bool                  xpos_down);
 
index b8c649ba058e3e41271ebebc0ae6aa39b01260c9..6a04b3f270a9858d3d1d463fadec3512290b6dbd 100755 (executable)
@@ -2,7 +2,7 @@
 
 cp -rpv ../llama.cpp/ggml.c              src/ggml.c
 cp -rpv ../llama.cpp/ggml-alloc.c        src/ggml-alloc.c
-cp -rpv ../llama.cpp/ggml-backend-impl.c src/ggml-backend-impl.c
+cp -rpv ../llama.cpp/ggml-backend-impl.h src/ggml-backend-impl.h
 cp -rpv ../llama.cpp/ggml-backend.c      src/ggml-backend.c
 cp -rpv ../llama.cpp/ggml-cuda.cu        src/ggml-cuda.cu
 cp -rpv ../llama.cpp/ggml-cuda.h         src/ggml-cuda.h
index f400dc56a20d054ee7d8879447bec42c2732d59e..cdfe4caf69613d6778f8a80af9b0d6ffc8e67652 100644 (file)
@@ -446,12 +446,14 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
     return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
 }
 
-static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view) {
+static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
     ggml_tallocr_t alloc = node_tallocr(galloc, view);
 
     //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
     GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
-    view->backend = view->view_src->backend;
+    if (update_backend) {
+        view->backend = view->view_src->backend;
+    }
     view->buffer  = view->view_src->buffer;
     view->data    = (char *)view->view_src->data + view->view_offs;
 
@@ -469,7 +471,7 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
 
     if (node->data == NULL) {
         if (ggml_is_view(node)) {
-            init_view(galloc, node);
+            init_view(galloc, node, true);
         } else {
             // see if we can reuse a parent's buffer (inplace)
             if (ggml_op_can_inplace(node->op)) {
@@ -499,15 +501,14 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
                                 AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
                                 node->view_src = view_src;
                                 view_src_hn->n_views += 1;
-                                init_view(galloc, node);
+                                init_view(galloc, node, false);
                                 return;
                             }
-                        }
-                        else {
+                        } else {
                             AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
                             node->view_src = parent;
                             p_hn->n_views += 1;
-                            init_view(galloc, node);
+                            init_view(galloc, node, false);
                             return;
                         }
                     }
@@ -537,7 +538,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
             hash_get(galloc, view_src)->n_views += 1;
             if (node->buffer == NULL && node->data != NULL) {
                 // view of a pre-allocated tensor, didn't call init_view() yet
-                init_view(galloc, node);
+                init_view(galloc, node, true);
             }
         }
 
@@ -548,7 +549,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
             }
             hash_get(galloc, parent)->n_children += 1;
             if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
-                init_view(galloc, parent);
+                init_view(galloc, parent, true);
             }
         }
    }
@@ -663,7 +664,7 @@ size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, st
     return max_size;
 }
 
-void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_alloct) {
+void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
     const size_t hash_size = hash_set.size;
 
     GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
@@ -686,7 +687,7 @@ void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * grap
     // reset hash values
     memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
 
-    galloc->hash_allocs = hash_node_alloct;
+    galloc->hash_allocs = hash_node_talloc;
 
     ggml_tallocr_alloc_graph_impl(galloc, graph);
 
index 058011a48f0da5189223145aab5c4aa26d1b3b3f..7be63925f4edadfb9433a9f73921a8544d378309 100644 (file)
@@ -39,7 +39,6 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
-#define cudaDeviceGetMemPool hipDeviceGetMemPool
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -49,7 +48,6 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
-#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
@@ -57,7 +55,6 @@
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
-#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -66,9 +63,6 @@
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 #define cudaMemcpyKind hipMemcpyKind
-#define cudaMemPool_t hipMemPool_t
-#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
-#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
 #define cudaMemset hipMemset
 #define cudaMemsetAsync hipMemsetAsync
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
@@ -188,11 +182,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cudaError_t err_ = (err);                                                       \
         if (err_ != cudaSuccess) {                                                      \
-            int dev_id;                                                                     \
-            cudaGetDevice(&dev_id);                                                         \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", dev_id);                                \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -202,11 +196,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int dev_id;                                                                     \
-            cudaGetDevice(&dev_id);                                                         \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
-            fprintf(stderr, "current device: %d\n", dev_id);                                \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -440,6 +434,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -472,7 +468,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
-static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -561,6 +556,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -990,7 +1003,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1094,7 +1107,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1198,7 +1211,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
@@ -1452,7 +1465,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -4262,7 +4275,7 @@ template <bool need_check> static __global__ void
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4302,7 +4315,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4741,7 +4754,7 @@ static  __global__ void im2col_f32_f16(
         int ofs0, int ofs1, int IW, int IH, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-       const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
 
     const int offset_dst =
         (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
@@ -4793,6 +4806,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -4901,7 +4924,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4910,7 +4934,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4919,7 +4943,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4928,7 +4952,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4937,7 +4961,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4947,7 +4971,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4956,7 +4980,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4965,7 +4989,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4980,7 +5004,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4988,7 +5012,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4997,7 +5021,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5006,7 +5030,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5015,7 +5039,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5024,7 +5048,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5033,7 +5057,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5042,7 +5066,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5051,7 +5075,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5060,7 +5084,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5069,7 +5093,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5088,7 +5112,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<1, 1, convert_f16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -5825,16 +5849,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_malloc(size, actual_size);
-    }
-    void *ptr;
-    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-    *actual_size = size;
-    return ptr;
-}
-
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5852,12 +5866,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
 
-static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_free(ptr, actual_size);
-    }
-    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
 }
 
 void ggml_init_cublas() {
@@ -5872,7 +5884,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5914,19 +5931,13 @@ void ggml_init_cublas() {
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-            // configure memory pool
-            cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-            if (err == cudaSuccess) {
-                size_t treshold = UINT64_MAX;
-                CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-            }
         }
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 
@@ -6193,6 +6204,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_sqr(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6514,7 +6553,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6525,12 +6564,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-        size_t dst_f16_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6548,15 +6587,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        if (dst_f16_as != 0) {
-            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-        }
+        ggml_cuda_pool_free(dst_f16, dst_as);
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
+
         if (src1_as != 0) {
-            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
     }
     else {
@@ -6566,7 +6604,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6583,7 +6621,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta,  dst_dd_i,   ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
         }
     }
 
@@ -7008,6 +7046,8 @@ static void ggml_cuda_op_mul_mat(
     int64_t  row_low[GGML_CUDA_MAX_DEVICES];
     int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
+    int used_devices = 0;
+
     for (int64_t id = 0; id < g_device_count; ++id) {
         // by default, use all rows
         row_low[id]  = 0;
@@ -7035,6 +7075,8 @@ static void ggml_cuda_op_mul_mat(
             continue;
         }
 
+        used_devices++;
+
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
@@ -7045,22 +7087,21 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
         if (convert_src1_to_q8_1) {
-            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                // CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaGetLastError());
             }
         }
 
@@ -7068,18 +7109,18 @@ static void ggml_cuda_op_mul_mat(
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id,  stream);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
     }
 
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
+    if (split && used_devices > 1) {
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
 
-    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7194,6 +7235,27 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+            continue;
+        }
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        }
+    }
+
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7201,6 +7263,9 @@ static void ggml_cuda_op_mul_mat(
 
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
+            if (row_low[id] == row_high[id]) {
+                continue;
+            }
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
@@ -7211,21 +7276,6 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
-
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-        }
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7252,6 +7302,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
+}
+
+static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
+}
+
 static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
@@ -7261,6 +7319,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7412,11 +7472,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7470,8 +7530,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         size_t ptrs_src_s = 0;
         size_t ptrs_dst_s = 0;
 
-        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7484,6 +7544,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
+
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -7495,30 +7556,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
         }
         if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
         }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-    if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-    }
-    if (dst_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-    }
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU) &&
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);
 
+    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
@@ -7540,13 +7600,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
@@ -7953,6 +8013,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
@@ -7995,6 +8057,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_RELU:
+                    func = ggml_cuda_relu;
+                    break;
                 default:
                     return false;
             } break;
@@ -8013,6 +8078,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_SQR:
+            func = ggml_cuda_sqr;
+            break;
         case GGML_OP_CLAMP:
             if (!any_on_device) {
                 return false;
index 57adc9cf34bc5bc4fae4d576b4c2b1572364b8ff..528e66c33a20738ce185744ff5780203c869e4ad 100644 (file)
@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
index 584ee4680378d9ddb3722d61e618d9837a07a88d..3202a517b78686da7ab1721781899cd09af05356 100644 (file)
@@ -5024,8 +5024,13 @@ struct ggml_tensor * ggml_rope_back(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
         float                 xpos_base,
         bool                  xpos_down) {
     GGML_ASSERT(ggml_is_vector(b));
@@ -5042,11 +5047,15 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base,  sizeof(float));
-    memcpy(params + 5, &freq_scale, sizeof(float));
-    memcpy(params + 6, &xpos_base,  sizeof(float));
-    memcpy(params + 7, &xpos_down,  sizeof(bool));
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &xpos_base,    sizeof(float));
+    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -9376,7 +9385,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
-
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10946,7 +10954,8 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        const bool forward) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -11005,6 +11014,11 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
     const int32_t * pos = (const int32_t *) src1->data;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11021,9 +11035,9 @@ static void ggml_compute_forward_rope_f32(
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
                         const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
+                        const float sin_theta = sinf(theta_base) * sin_sign;
                         const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta);
+                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 
                         theta_base *= theta_scale;
                         block_theta *= theta_scale;
@@ -11047,6 +11061,7 @@ static void ggml_compute_forward_rope_f32(
                         rope_yarn(
                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
                         );
+                        sin_theta *= sin_sign;
 
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@@ -11077,6 +11092,7 @@ static void ggml_compute_forward_rope_f32(
                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
+                            sin_theta *= sin_sign;
 
                             theta_base *= theta_scale;
 
@@ -11102,7 +11118,8 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        const bool forward) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -11154,6 +11171,11 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
     const int32_t * pos = (const int32_t *) src1->data;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11170,9 +11192,9 @@ static void ggml_compute_forward_rope_f16(
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
                         const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
+                        const float sin_theta = sinf(theta_base) * sin_sign;
                         const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta);
+                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 
                         theta_base *= theta_scale;
                         block_theta *= theta_scale;
@@ -11196,6 +11218,7 @@ static void ggml_compute_forward_rope_f16(
                         rope_yarn(
                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
                         );
+                        sin_theta *= sin_sign;
 
                         theta_base *= theta_scale;
 
@@ -11222,6 +11245,7 @@ static void ggml_compute_forward_rope_f16(
                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
+                            sin_theta *= sin_sign;
 
                             theta_base *= theta_scale;
 
@@ -11251,11 +11275,11 @@ static void ggml_compute_forward_rope(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
             } break;
         default:
             {
@@ -11266,216 +11290,6 @@ static void ggml_compute_forward_rope(
 
 // ggml_compute_forward_rope_back
 
-static void ggml_compute_forward_rope_back_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // y = rope(x, src1)
-    // dx = rope_back(dy, src1)
-    // src0 is dy, src1 contains options
-
-    float freq_base;
-    float freq_scale;
-
-    // these two only relevant for xPos RoPE:
-    float xpos_base;
-    bool xpos_down;
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    assert(nb0 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    const bool is_neox = mode & 2;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                float theta_base = freq_scale * (float)p;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
-
-                        // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
-                        if (xpos_down) zeta = 1.0f / zeta;
-
-                        theta_base *= theta_scale;
-
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float dy0 = dy[0];
-                        const float dy1 = dy[1];
-
-                        dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
-                        dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
-                    }
-                } else {
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta_base);
-                            const float sin_theta = sinf(theta_base);
-
-                            theta_base *= theta_scale;
-
-                            const int64_t i0 = ib*n_dims + ic/2;
-
-                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                            const float dy0 = dy[0];
-                            const float dy1 = dy[n_dims/2];
-
-                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
-                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rope_back_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // y = rope(x, src1)
-    // dx = rope_back(dy, src1)
-    // src0 is dy, src1 contains options
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    assert(nb0 == sizeof(ggml_fp16_t));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
-
-    const bool is_neox = mode & 2;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                float theta_base = (float)p;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base);
-
-                        theta_base *= theta_scale;
-
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                        const float dy1 = GGML_FP16_TO_FP32(dy[1]);
-
-                        dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                        dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                    }
-                } else {
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta_base);
-                            const float sin_theta = sinf(theta_base);
-
-                            theta_base *= theta_scale;
-
-                            const int64_t i0 = ib*n_dims + ic/2;
-
-                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
-
-                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
 static void ggml_compute_forward_rope_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -11484,11 +11298,11 @@ static void ggml_compute_forward_rope_back(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
             } break;
         default:
             {
@@ -14923,17 +14737,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims = ((int32_t *) tensor->op_params)[1];
-                    const int mode   = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
-                    float freq_base;
-                    float freq_scale;
-                    float xpos_base;
-                    bool  xpos_down;
-                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
+                    const int mode       = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
+                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
+                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
+                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
+                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -14943,8 +14760,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_dims,
                                 mode,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down),
                             zero_table);
@@ -14954,17 +14776,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims = ((int32_t *) tensor->op_params)[1];
-                    const int mode   = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
-                    float freq_base;
-                    float freq_scale;
-                    float xpos_base;
-                    bool  xpos_down;
-                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
+                    const int mode       = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
+                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
+                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
+                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
+                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -14973,14 +14798,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src1,
                                 n_dims,
                                 mode,
-                                0,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
-                                0.0f,
-                                1.0f,
-                                0.0f,
-                                0.0f,
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down,
                                 false),