]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (CUDA add/mul bcast + Metal fix + custom RoPE)
authorGeorgi Gerganov <redacted>
Sat, 15 Jul 2023 11:31:18 +0000 (14:31 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Jul 2023 11:31:18 +0000 (14:31 +0300)
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c

index b88c35bae72670d52ca6ce4ae2ca2dfe5a69e8f8..24856a255c9a6d57683b024c979f159e7f78f64e 100644 (file)
@@ -1121,6 +1121,17 @@ extern "C" {
             int                   mode,
             int                   n_ctx);
 
+    // custom RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            float                 freq_base,
+            float                 freq_scale,
+            int                   n_ctx);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
index 920466aae72db8f01c44e759a3c7c6a0b856cf1f..0646fa7b2582dcf06170b5fe2c71a71daed0e538 100644 (file)
@@ -13,6 +13,8 @@
 #include "ggml-cuda.h"
 #include "ggml.h"
 
+#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -74,7 +76,7 @@ typedef void (*ggml_cuda_op_t)(
 
 #define QK4_0 32
 #define QR4_0 2
-#define QI4_0 4
+#define QI4_0 (QK4_0 / (4 * QR4_0))
 typedef struct {
     half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
@@ -83,7 +85,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 
 #define QK4_1 32
 #define QR4_1 2
-#define QI4_1 4
+#define QI4_1 (QK4_1 / (4 * QR4_1))
 typedef struct {
     half    d;              // delta
     half    m;              // min
@@ -93,7 +95,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
 
 #define QK5_0 32
 #define QR5_0 2
-#define QI5_0 4
+#define QI5_0 (QK5_0 / (4 * QR5_0))
 typedef struct {
     half d;                 // delta
     uint8_t qh[4];          // 5-th bit of quants
@@ -103,7 +105,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 
 #define QK5_1 32
 #define QR5_1 2
-#define QI5_1 4
+#define QI5_1 (QK5_1 / (4 * QR5_1))
 typedef struct {
     half d;                 // delta
     half m;                 // min
@@ -114,7 +116,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 #define QR8_0 1
-#define QI8_0 8
+#define QI8_0 (QK8_0 / (4 * QR8_0))
 typedef struct {
     half    d;              // delta
     int8_t  qs[QK8_0];      // quants
@@ -123,7 +125,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 #define QK8_1 32
 #define QR8_1 1
-#define QI8_1 8
+#define QI8_1 (QK8_1 / (4 * QR8_1))
 typedef struct {
     half    d;              // delta
     half    s;              // unquantized sum
@@ -143,6 +145,8 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 #define K_SCALE_SIZE 12
 #endif
 
+#define QR2_K 4
+#define QI2_K (QK_K / (4*QR2_K))
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
@@ -151,6 +155,8 @@ typedef struct {
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
@@ -163,6 +169,8 @@ typedef struct {
 } block_q3_K;
 //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#define QR4_K 2
+#define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
     half    d[2];              // super-block scales/mins
@@ -180,6 +188,8 @@ typedef struct {
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
 #endif
 
+#define QR5_K 2
+#define QI5_K (QK_K / (4*QR5_K))
 #ifdef GGML_QKK_64
 typedef struct {
     half d;                  // super-block scale
@@ -199,6 +209,8 @@ typedef struct {
 static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
 #endif
 
+#define QR6_K 2
+#define QI6_K (QK_K / (4*QR6_K))
 typedef struct {
     uint8_t ql[QK_K/2];   // quants, lower 4 bits
     uint8_t qh[QK_K/4];   // quants, upper 2 bits
@@ -240,13 +252,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -1271,8 +1283,9 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
     int vi;
@@ -1293,11 +1306,12 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
     const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1318,11 +1332,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
     return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
     int qs;
@@ -1353,11 +1368,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
     const int qs  = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1387,11 +1403,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
     return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
     int vi;
@@ -1406,7 +1423,220 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq2_K->d;
+    const float dmin = bq2_K->dmin;
+
+    const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]);
+
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = bq2_K->scales[scale_offset + 2*i];
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const float d8i = bq8i->d;
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * (sc & 0xF)); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >>  4)); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    float sumf = 0.0f;
+
+    const float d = bq3_K->d;
+
+    int vl;
+    memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
+
+    int vh;
+    memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
+    vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted
+    vh >>= bq8_offset;
+
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    const int bq8_offset = QR4_K * (iqs / QI8_1);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq4_K->d;
+    const float dmin = bq4_K->dmin;
+
+    const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const int isc = bq8_offset + i;
+
+        uint8_t sc, m;
+        get_scale_min_k4(isc, bq4_K->scales, sc, m);
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    const int bq8_offset = QR5_K * (iqs / QI8_1);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq5_K->d;
+    const float dmin = bq5_K->dmin;
+
+    const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+
+    const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+
+    for (int i = 0; i < QR5_K; ++i) {
+        const int isc = bq8_offset + i;
+
+        uint8_t sc, m;
+        get_scale_min_k4(isc, bq5_K->scales, sc, m);
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> i) << 4) & 0x10101010;
+
+        const int vi = vil | vih;
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q5_K with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+    float sumf = 0.0f;
+
+    const float d = bq6_K->d;
+
+    int vl;
+    memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int));
+
+    int vh;
+    memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int));
+
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = bq6_K->scales[scale_offset + 4*i];
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -1429,7 +1659,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
         const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
 
-        const int iby = i + threadIdx.x / qi; // y block index
+        const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx
 
         const int iqs  = threadIdx.x % qi; // x block quant index when casting the quants to int
 
@@ -1766,9 +1996,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1962,7 +2192,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 }
 
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1971,7 +2201,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1980,7 +2210,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1989,7 +2219,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1998,7 +2228,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -2006,6 +2236,51 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, vec_dot_q2_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -2335,17 +2610,15 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i  != nullptr);
 
-    // TODO: support broadcasting
-    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
-
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    // const int64_t ne10 = src1->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
 
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
@@ -2369,19 +2642,12 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i  != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
-
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01  = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2494,13 +2760,22 @@ inline void ggml_cuda_op_mul_mat_vec(
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
-    const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
+    bool mul_mat_vec_q_implemented =
+        src0->type == GGML_TYPE_Q4_0 ||
         src0->type == GGML_TYPE_Q4_1 ||
         src0->type == GGML_TYPE_Q5_0 ||
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
-
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
+#if QK_K == 256
+    mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
+        src0->type == GGML_TYPE_Q2_K ||
+        src0->type == GGML_TYPE_Q3_K ||
+        src0->type == GGML_TYPE_Q4_K ||
+        src0->type == GGML_TYPE_Q5_K ||
+        src0->type == GGML_TYPE_Q6_K;
+#endif // QK_K == 256
+
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
@@ -2526,6 +2801,21 @@ inline void ggml_cuda_op_mul_mat_vec(
             case GGML_TYPE_Q8_0:
                 mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
                 break;
+            case GGML_TYPE_Q2_K:
+                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q3_K:
+                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q4_K:
+                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_K:
+                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q6_K:
+                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
             default:
                 GGML_ASSERT(false);
                 break;
@@ -3356,6 +3646,22 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
+static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
+static size_t g_temp_tensor_extra_index = 0;
+
+static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
+    if (g_temp_tensor_extras == nullptr) {
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+    }
+
+    size_t alloc_index = g_temp_tensor_extra_index;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+    struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
+    memset(extra, 0, sizeof(*extra));
+
+    return extra;
+}
+
 void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
@@ -3373,8 +3679,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     }
 
     tensor->backend = GGML_BACKEND_GPU;
-    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
-    memset(extra, 0, sizeof(*extra));
+    struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
         tensor->op == GGML_OP_VIEW ||
@@ -3389,10 +3694,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         if (tensor->op == GGML_OP_VIEW) {
             memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
     } else if (tensor->op == GGML_OP_CPY) {
         struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
         void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src1_ddv;
     } else if (scratch) {
         GGML_ASSERT(size <= g_scratch_size);
@@ -3405,6 +3712,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
             CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
             g_scratch_buffer = data;
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = data + g_scratch_offset;
 
         g_scratch_offset += size;
@@ -3414,6 +3722,8 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         void * data;
         CUDA_CHECK(cudaMalloc(&data, size));
         CUDA_CHECK(cudaMemset(data, 0, size));
+        extra = new ggml_tensor_extra_gpu;
+        memset(extra, 0, sizeof(*extra));
         extra->data_device[g_main_device] = data;
     }
 
index c795ee22784bda5b7b0dfd0e0631b16ba398f1fb..ee205bcdf773ca1666016ac63aae16a7aafdfb96 100644 (file)
@@ -881,28 +881,35 @@ void ggml_metal_graph_compute(
 
                             const int n_past = ((int32_t *)(src1->data))[0];
 
+                            float freq_base;
+                            float freq_scale;
+                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
-                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
index f094a1d407f08b21f6493c105218e891a3cc3f37..9f9a4fbd74446e2c203fe2f8c65653ee6f5360f7 100644 (file)
@@ -656,17 +656,19 @@ kernel void kernel_rope(
         constant       int & n_past,
         constant       int & n_dims,
         constant       int & mode,
+        constant     float & freq_base,
+        constant     float & freq_scale,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i3 = tpig[2];
     const int64_t i2 = tpig[1];
     const int64_t i1 = tpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = (float)p;
+    float theta = freq_scale * (float)p;
 
     if (!is_neox) {
         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
index f5821f1f1aa042d3c04eabc36b9aca9678ea4e87..5ce1da0e9df4dc503903efacf926b35497ffee70 100644 (file)
 #include <unistd.h>
 #endif
 
+// static_assert should be a #define, but if it's not,
+// fall back to the _Static_assert C11 keyword.
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
+#endif
 
 #if defined(_MSC_VER)
 // disable "possible loss of data" to avoid hundreds of casts
@@ -112,10 +118,6 @@ typedef void * thread_ret_t;
 #endif
 #endif
 
-#ifdef __HAIKU__
-#define static_assert(cond, msg) _Static_assert(cond, msg)
-#endif
-
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -6954,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
         int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
@@ -6967,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
     ((int32_t *) b->data)[3] = n_ctx;
+    memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
+    memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
 
     ggml_scratch_load(ctx);
 
@@ -6991,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7001,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
 }
 
 // ggml_rope_back
@@ -12072,16 +12090,21 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float freq_base;
+    float freq_scale;
+
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12110,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12122,7 +12145,7 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (is_glm) {
                     theta = MIN(p, n_ctx - 2);
@@ -12199,16 +12222,21 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float freq_base;
+    float freq_scale;
+
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12237,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12249,7 +12277,7 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (is_glm) {
                     theta = MIN(p, n_ctx - 2);
@@ -12310,7 +12338,7 @@ static void ggml_compute_forward_rope_f16(
                             const float x0 = GGML_FP16_TO_FP32(src[0]);
                             const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
 
-                            dst_data[0]     = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                             dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                         }
                     }
@@ -15708,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 4);
+                    assert(ggml_nelements(src1) == 6);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
@@ -15729,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 4);
+                    assert(ggml_nelements(src1) == 3);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];