]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : sync latest (SAM + SD operators, CUDA alibi) (#2709)
authorGeorgi Gerganov <redacted>
Tue, 22 Aug 2023 11:22:08 +0000 (14:22 +0300)
committerGitHub <redacted>
Tue, 22 Aug 2023 11:22:08 +0000 (14:22 +0300)
* ggml : sync latest (SAM + SD operators, CUDA alibi)

ggml-ci

* ggml : fix tabs

examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-alloc.c
ggml-cuda.cu
ggml.c
ggml.h
scripts/sync-ggml.sh

index 31d6620a235013b3d1ec73dc97a3e8089ae1fe1d..79b117df72fd33efe5ccabdc8f8e7c93f4894f87 100644 (file)
@@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1));                                            assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
         t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd));                 assert_shape_2d(t11->grad, N*n_batch, n_embd);
         t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3));                                            assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false));        assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
         t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch));                                  assert_shape_2d(t08->grad, n_embd, N*n_batch);
         t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3));                                            assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false));        assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
         t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch));                                  assert_shape_2d(t05->grad, n_embd, N*n_batch);
         t04->grad = expand(gb, ggml_add_inplace(ctx0,
                         ggml_add_inplace(ctx0,
index 3ee98d03dea4d0074944e67f67c11c74fc71cf67..f06f9a3c1d97b97c700631fe810f62fec280310a 100644 (file)
@@ -76,7 +76,7 @@ struct ggml_allocr {
 };
 
 #ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == NULL) {
             alloc->allocated_tensors[i] = tensor;
@@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens
     }
     GGML_ASSERT(!"out of allocated_tensors");
 }
-static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == tensor ||
             (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
index 5b415c646e8c6fc87c6522bfb85a48f9573baa6b..c0fb9fb650e0d22b04a444ef9056b8129f0ac392 100644 (file)
@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }
 
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }
 
+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }
 
+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }
diff --git a/ggml.c b/ggml.c
index c917d73c7e0d4ba4ab09fcf700ea0c58019802fe..dffb977313584e5efd8b75d34e29e3bbe61724a1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -216,7 +216,6 @@ inline static void * ggml_aligned_malloc(size_t size) {
         GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
         return NULL;
     }
-
     return aligned_memory;
 }
 #define GGML_ALIGNED_MALLOC(size)  ggml_aligned_malloc(size)
@@ -3722,6 +3721,10 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
     *s = idx;
 }
 
+//
+// data types
+//
+
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
 
@@ -3741,10 +3744,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGMAX",
     "REPEAT",
     "REPEAT_BACK",
+    "CONCAT",
     "SILU_BACK",
     "NORM",
     "RMS_NORM",
     "RMS_NORM_BACK",
+    "GROUP_NORM",
 
     "MUL_MAT",
     "OUT_PROD",
@@ -3770,20 +3775,28 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CLAMP",
     "CONV_1D",
     "CONV_2D",
+    "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
+    "UPSCALE",
 
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "WIN_PART",
     "WIN_UNPART",
+    "GET_REL_POS",
+    "ADD_REL_POS",
 
     "UNARY",
 
     "MAP_UNARY",
     "MAP_BINARY",
 
+    "MAP_CUSTOM1_F32",
+    "MAP_CUSTOM2_F32",
+    "MAP_CUSTOM3_F32",
+
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
@@ -3792,7 +3805,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 62, "GGML_OP_COUNT != 62");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3813,10 +3826,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argmax(x)",
     "repeat(x)",
     "repeat_back(x)",
+    "concat(x, y)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
+    "group_norm(x)",
 
     "X*Y",
     "X*Y",
@@ -3842,20 +3857,28 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "clamp(x)",
     "conv_1d(x)",
     "conv_2d(x)",
+    "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
+    "upscale(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "win_part(x)",
     "win_unpart(x)",
+    "get_rel_pos(x)",
+    "add_rel_pos(x)",
 
     "unary(x)",
 
     "f(x)",
     "f(x,y)",
 
+    "custom_f32(x)",
+    "custom_f32(x,y)",
+    "custom_f32(x,y,z)",
+
     "custom(x)",
     "custom(x,y)",
     "custom(x,y,z)",
@@ -3864,7 +3887,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 62, "GGML_OP_COUNT != 62");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3894,8 +3917,10 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_DIAG_MASK_ZERO         ] = true;
         p[GGML_OP_CONV_1D                ] = true;
         p[GGML_OP_CONV_2D                ] = true;
+        p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
         p[GGML_OP_FLASH_ATTN_BACK        ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+        p[GGML_OP_ADD_REL_POS            ] = true;
     }
 
     {   // FINALIZE
@@ -5572,6 +5597,30 @@ struct ggml_tensor * ggml_repeat_back(
     return result;
 }
 
+// ggml_concat
+
+struct ggml_tensor* ggml_concat(
+    struct ggml_context* ctx,
+    struct ggml_tensor* a,
+    struct ggml_tensor* b) {
+    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+
+    result->op = GGML_OP_CONCAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs(
@@ -5771,6 +5820,8 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+// ggml_rms_norm
+
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5807,6 +5858,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
     return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
+// ggml_rms_norm_back
+
 struct ggml_tensor * ggml_rms_norm_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5828,6 +5881,44 @@ struct ggml_tensor * ggml_rms_norm_back(
     return result;
 }
 
+// ggml_group_norm
+
+static struct ggml_tensor * ggml_group_norm_impl(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups,
+    bool inplace) {
+
+    bool is_node = false;
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_GROUP_NORM;
+    result->op_params[0] = n_groups;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_group_norm(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups) {
+    return ggml_group_norm_impl(ctx, a, n_groups, false);
+}
+
+struct ggml_tensor * ggml_group_norm_inplace(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups) {
+    return ggml_group_norm_impl(ctx, a, n_groups, true);
+}
 
 // ggml_mul_mat
 
@@ -6696,6 +6787,8 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 xpos_base,
+        bool                  xpos_down,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6706,9 +6799,11 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base,  sizeof(float));
+    memcpy(params + 7, &xpos_down,  sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -6725,7 +6820,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6735,7 +6830,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -6747,7 +6842,7 @@ struct ggml_tensor * ggml_rope_custom(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -6759,7 +6854,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+}
+
+struct ggml_tensor * ggml_rope_xpos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        float                 base,
+        bool                  down) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
 }
 
 // ggml_rope_back
@@ -6770,7 +6875,11 @@ struct ggml_tensor * ggml_rope_back(
         int                   n_past,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx) {
+        int                   n_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 xpos_base,
+        bool                  xpos_down) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6782,7 +6891,11 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    memcpy(params + 4, &freq_base,  sizeof(float));
+    memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base,  sizeof(float));
+    memcpy(params + 7, &xpos_down,  sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -6889,6 +7002,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
     return result;
 }
 
+// ggml_conv_1d_ph
+
+struct ggml_tensor* ggml_conv_1d_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s,
+        int                   d) {
+    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+}
+
 // ggml_conv_2d
 
 struct ggml_tensor * ggml_conv_2d(
@@ -6929,17 +7053,59 @@ struct ggml_tensor * ggml_conv_2d(
 
 }
 
-// ggml_conv_1d_ph
+// ggml_conv_2d_sk_p0
 
-struct ggml_tensor * ggml_conv_1d_ph(
+struct ggml_tensor * ggml_conv_2d_sk_p0(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   s,
-        int                   d) {
-    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+        struct ggml_tensor  * b) {
+    return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
+}
+
+// ggml_conv_2d_s1_ph
+
+struct ggml_tensor * ggml_conv_2d_s1_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
 }
 
+// ggml_conv_transpose_2d_p0
+
+static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
+    return (ins - 1) * s - 2 * p + ks;
+}
+
+struct ggml_tensor * ggml_conv_transpose_2d_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   stride) {
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = {
+        ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
+        ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
+        a->ne[2], b->ne[3],
+    };
+
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    result->op = GGML_OP_CONV_TRANSPOSE_2D;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ggml_new_i32(ctx, stride);
+
+    return result;
+}
 
 // ggml_pool_*
 
@@ -7017,6 +7183,40 @@ struct ggml_tensor * ggml_pool_2d(
     return result;
 }
 
+// ggml_upscale
+
+static struct ggml_tensor * ggml_upscale_impl(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int scale_factor) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] * scale_factor,
+            a->ne[1] * scale_factor,
+            a->ne[2], a->ne[3]);
+
+    result->op = GGML_OP_UPSCALE;
+    result->op_params[0] = scale_factor;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_upscale(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int scale_factor) {
+    return ggml_upscale_impl(ctx, a, scale_factor);
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -7215,6 +7415,87 @@ struct ggml_tensor * ggml_win_unpart(
     return result;
 }
 
+// ggml_get_rel_pos
+
+struct ggml_tensor * ggml_get_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   qh,
+        int                   kh) {
+    GGML_ASSERT(qh == kh);
+    GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
+
+    result->op   = GGML_OP_GET_REL_POS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = NULL;
+
+    return result;
+}
+
+// ggml_add_rel_pos
+
+static struct ggml_tensor * ggml_add_rel_pos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_are_same_shape(pw, ph));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(pw));
+    GGML_ASSERT(ggml_is_contiguous(ph));
+    GGML_ASSERT(ph->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->ne[3] == a->ne[2]);
+    GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
+    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || pw->grad || ph->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
+
+    result->op   = GGML_OP_ADD_REL_POS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = pw;
+    result->src[2] = ph;
+
+    return result;
+}
+
+
+struct ggml_tensor * ggml_add_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
+}
+
+struct ggml_tensor * ggml_add_rel_pos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
+}
+
 // gmml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -9718,6 +9999,72 @@ static void ggml_compute_forward_repeat_back(
     }
 }
 
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2++) {
+            if (i2 < ne02) { // src0
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
+
+                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
+                        *y = *x;
+                    }
+                }
+            } // src1
+            else {
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
+
+                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
+                        *y = *x;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat(
+    const struct ggml_compute_params* params,
+    const struct ggml_tensor* src0,
+    const struct ggml_tensor* src1,
+    struct ggml_tensor* dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_concat_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -10321,6 +10668,8 @@ static void ggml_compute_forward_norm(
     }
 }
 
+// ggml_compute_forward_group_rms_norm
+
 static void ggml_compute_forward_rms_norm_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10385,7 +10734,6 @@ static void ggml_compute_forward_rms_norm(
     }
 }
 
-
 static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10559,42 +10907,132 @@ static void ggml_compute_forward_rms_norm_back(
     }
 }
 
-// ggml_compute_forward_mul_mat
+// ggml_compute_forward_group_norm
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_compute_forward_mul_mat_use_blas(
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    //const int64_t ne00 = src0->ne[0];
-    //const int64_t ne01 = src0->ne[1];
+static void ggml_compute_forward_group_norm_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    const int64_t ne10 = src1->ne[0];
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
-    return false;
-}
-#endif
+    const float eps = 1e-6f; // TODO: make this a parameter
 
-static void ggml_compute_forward_mul_mat(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
+    // TODO: optimize
+
+    int n_channels = src0->ne[2];
+    int n_groups = dst->op_params[0];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+    for (int i = ith; i < n_groups; i+=nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            ggml_float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sum += (ggml_float)x[i00];
+                    }
+                }
+            }
+            float mean = sum / (ne00 * ne01 * step);
+            ggml_float sum2 = 0.0;
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sum2 += (ggml_float)(v * v);
+                    }
+                }
+            }
+            float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_group_norm(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_group_norm_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mul_mat
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_compute_forward_mul_mat_use_blas(
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    //const int64_t ne00 = src0->ne[0];
+    //const int64_t ne01 = src0->ne[1];
+
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
+        return true;
+    }
+
+    return false;
+}
+#endif
+
+static void ggml_compute_forward_mul_mat(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
     GGML_TENSOR_BINARY_OP_LOCALS;
@@ -10625,6 +11063,10 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
@@ -10644,11 +11086,6 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        //       ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith != 0) {
             return;
         }
@@ -10661,12 +11098,16 @@ static void ggml_compute_forward_mul_mat(
             return;
         }
 
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+        for (int64_t i13 = 0; i13 < ne13; i13++) {
+            for (int64_t i12 = 0; i12 < ne12; i12++) {
+                // broadcast src0 into src1 across 2nd,3rd dimension
+                const int64_t i03 = i13/r3;
+                const int64_t i02 = i12/r2;
 
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+
+                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                             float * const wdata    = params->wdata;
@@ -10674,7 +11115,7 @@ static void ggml_compute_forward_mul_mat(
 
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        to_float((const char *) x + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
 
@@ -10754,10 +11195,6 @@ static void ggml_compute_forward_mul_mat(
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
     // block-tiling attempt
     const int64_t blck_0 = 16;
     const int64_t blck_1 = 16;
@@ -11913,7 +12350,6 @@ static void ggml_compute_forward_alibi(
     }
 }
 
-
 // ggml_compute_forward_clamp
 
 static void ggml_compute_forward_clamp_f32(
@@ -12002,12 +12438,18 @@ static void ggml_compute_forward_rope_f32(
     float freq_base;
     float freq_scale;
 
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_down;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
 
@@ -12079,6 +12521,9 @@ static void ggml_compute_forward_rope_f32(
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
+                        // zeta scaling for xPos only:
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        if (xpos_down) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
 
@@ -12088,8 +12533,8 @@ static void ggml_compute_forward_rope_f32(
                         const float x0 = src[0];
                         const float x1 = src[1];
 
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                        dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
+                        dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
                     // TODO: this is probably wrong, but I can't figure it out ..
@@ -12283,9 +12728,21 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
+    float freq_base;
+    float freq_scale;
+
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_down;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
 
@@ -12311,7 +12768,7 @@ static void ggml_compute_forward_rope_back_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
 
@@ -12322,12 +12779,15 @@ static void ggml_compute_forward_rope_back_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
+                        // zeta scaling for xPos only:
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        if (xpos_down) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
 
@@ -12337,8 +12797,8 @@ static void ggml_compute_forward_rope_back_f32(
                         const float dy0 = dy[0];
                         const float dy1 = dy[1];
 
-                        dx[0] =   dy0*cos_theta + dy1*sin_theta;
-                        dx[1] = - dy0*sin_theta + dy1*cos_theta;
+                        dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
+                        dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
                     }
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
@@ -13031,6 +13491,108 @@ static void ggml_compute_forward_conv_2d(
     }
 }
 
+// ggml_compute_forward_conv_transpose_2d
+
+static void ggml_compute_forward_conv_transpose_2d(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02*ne03;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    for (int64_t i01 = 0; i01 < ne01; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
+                        }
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            for (int i12 = 0; i12 < ne12; i12++) {
+                for (int i11 = 0; i11 < ne11; i11++) {
+                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
+                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    for (int i10 = 0; i10 < ne10; i10++) {
+                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int32_t stride = ((const int32_t*)(opt0->data))[0];
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = (ggml_fp16_t *) params->wdata + nk;
+
+    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        for (int i11 = 0; i11 < ne11; i11++) {
+            for (int i10 = 0; i10 < ne10; i10++) {
+                const int i1n = i11*ne10*ne12 + i10*ne12;
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        float v = 0;
+                        ggml_vec_dot_f16(ne03, &v,
+                                (ggml_fp16_t *) wdata_src + i1n,
+                                (ggml_fp16_t *) wdata_kernel + i01*ne00*ne03 + i00*ne03);
+
+                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
+                    }
+                }
+            }
+        }
+    }
+}
+
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(
@@ -13189,6 +13751,60 @@ static void ggml_compute_forward_pool_2d(
     ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
 }
 
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const int scale_factor = dst->op_params[0];
+
+    // TODO: optimize
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = ith; i02 < ne02; i02++) {
+            for (int m = 0; m < dst->ne[1]; m++) {
+                int i01 = m / scale_factor;
+                for (int n = 0; n < dst->ne[0]; n++) {
+                    int i00 = n / scale_factor;
+
+                    const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_upscale(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_upscale_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
 
 // ggml_compute_forward_flash_attn
 
@@ -14314,6 +14930,137 @@ static void ggml_compute_forward_unary(
     }
 }
 
+// ggml_compute_forward_get_rel_pos
+
+static void ggml_compute_forward_get_rel_pos_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const int64_t w = ne1;
+
+    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (w - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rel_pos(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rel_pos_f16(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_add_rel_pos
+
+static void ggml_compute_forward_add_rel_pos_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    if (!inplace && params->type == GGML_TASK_INIT) {
+        memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        return;
+    }
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+
+    float * src1_data = (float *) src1->data;
+    float * src2_data = (float *) src2->data;
+    float * dst_data  = (float *) dst->data;
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // total patches in dst
+    const int np = ne13;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+
+    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
+                for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                    const int64_t jp0  = jp1 + i10;
+                    const float src1_e = src1_data[jp0];
+                    const float src2_e = src2_data[jp0];
+
+                    const int64_t jdh = jp0 * ne10;
+                    const int64_t jdw = jdh - (ne10 - 1) * i10;
+
+                    for (int64_t j = 0; j < ne10; ++j) {
+                        dst_data[jdh + j     ] += src2_e;
+                        dst_data[jdw + j*ne10] += src1_e;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_rel_pos(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14866,6 +15613,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_CONCAT:
+            {
+                ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
         case GGML_OP_SILU_BACK:
             {
                 ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
@@ -14882,6 +15633,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                ggml_compute_forward_group_norm(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
@@ -14974,6 +15729,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+            } break;
         case GGML_OP_POOL_1D:
             {
                 ggml_compute_forward_pool_1d(params, tensor->src[0], tensor);
@@ -14982,6 +15741,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pool_2d(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_UPSCALE:
+            {
+                ggml_compute_forward_upscale(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -15012,6 +15775,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_unary(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_GET_REL_POS:
+            {
+                ggml_compute_forward_get_rel_pos(params, tensor->src[0], tensor);
+            } break;
+        case GGML_OP_ADD_REL_POS:
+            {
+                ggml_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -15275,6 +16046,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             inplace);
                 }
             } break;
+        case GGML_OP_CONCAT:
+            {
+                GGML_ASSERT(false); // TODO: implement
+            } break;
         case GGML_OP_SILU_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -15297,6 +16072,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 // https://cs231n.github.io/optimization-2/#staged
@@ -15571,6 +16350,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                    float freq_base;
+                    float freq_scale;
+                    float xpos_base;
+                    bool  xpos_down;
+                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
@@ -15578,7 +16366,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_down),
                             inplace);
                 }
             } break;
@@ -15589,14 +16381,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                    float freq_base;
+                    float freq_scale;
+                    float xpos_base;
+                    bool  xpos_down;
+                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope(ctx,
+                            ggml_rope_impl(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_down,
+                                false),
                             inplace);
                 }
             } break;
@@ -15616,6 +16422,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_POOL_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -15624,6 +16434,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_UPSCALE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15865,6 +16679,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         GGML_ASSERT(false);
                 }
             } break;
+        case GGML_OP_GET_REL_POS:
+        case GGML_OP_ADD_REL_POS:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -16441,9 +17257,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
+            case GGML_OP_GROUP_NORM:
                 {
                     n_tasks = n_threads;
                 } break;
+            case GGML_OP_CONCAT:
             case GGML_OP_MUL_MAT:
             case GGML_OP_OUT_PROD:
                 {
@@ -16511,6 +17329,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_SOFT_MAX_BACK:
             case GGML_OP_ROPE:
             case GGML_OP_ROPE_BACK:
+            case GGML_OP_ADD_REL_POS:
                 {
                     n_tasks = n_threads;
                 } break;
@@ -16585,6 +17404,25 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_TRANSPOSE_2D:
+                {
+                    n_tasks = n_threads;
+
+                    const int64_t ne00 = node->src[0]->ne[0]; // W
+                    const int64_t ne01 = node->src[0]->ne[1]; // H
+                    const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
+                    const int64_t ne03 = node->src[0]->ne[3]; // Channels In
+
+                    const int64_t ne10 = node->src[1]->ne[0]; // W
+                    const int64_t ne11 = node->src[1]->ne[1]; // H
+                    const int64_t ne12 = node->src[1]->ne[2]; // Channels In
+
+                    size_t cur = 0;
+                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
+                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_POOL_1D:
@@ -16592,6 +17430,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 {
                     n_tasks = 1;
                 } break;
+            case GGML_OP_UPSCALE:
+                {
+                    n_tasks = n_threads;
+                } break;
             case GGML_OP_FLASH_ATTN:
                 {
                     n_tasks = n_threads;
@@ -16653,6 +17495,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_WIN_PART:
             case GGML_OP_WIN_UNPART:
+            case GGML_OP_GET_REL_POS:
             case GGML_OP_MAP_UNARY:
             case GGML_OP_MAP_BINARY:
             case GGML_OP_MAP_CUSTOM1_F32:
@@ -16770,8 +17613,10 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
 
             const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
             GGML_ASSERT(rc == 0);
+            UNUSED(rc);
         }
     }
+
     workers[0].ith = 0;
     workers[0].shared = &state_shared;
 
diff --git a/ggml.h b/ggml.h
index 0ec7ec5bf95231d9de7d948c4bd2e1397bc1bc23..3c48fd27fab39d2986f76f3fc4a3b1543a6cb915 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 
+
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
@@ -345,10 +346,12 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_OUT_PROD,
@@ -374,14 +377,19 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
 
+        GGML_OP_UPSCALE, // nearest interpolate
+
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,
 
         GGML_OP_UNARY,
 
@@ -805,6 +813,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // concat a and b on dim 2
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -913,6 +928,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 eps);
 
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
     // a - x
     // b - dy
     // TODO: update with configurable eps
@@ -1213,6 +1241,15 @@ extern "C" {
             float                 freq_base,
             float                 freq_scale);
 
+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            float                 base,
+            bool                  down);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1221,7 +1258,11 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx);
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 xpos_base,
+            bool                  xpos_down);
 
     // alibi position embedding
     // in-place, returns view(a)
@@ -1248,6 +1289,15 @@ extern "C" {
             int                   p0,  // padding
             int                   d0); // dilation
 
+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s,
+            int                   d);
+
     GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1259,14 +1309,38 @@ extern "C" {
             int                   d0,
             int                   d1);
 
-    // conv_1d with padding = half
-    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-    GGML_API struct ggml_tensor * ggml_conv_1d_ph(
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a:      3    3    256  256
+    // b:     64   64    256    1
+    // res:   64   64    256    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            int                   s,
-            int                   d);
+            int                   stride);
 
     enum ggml_op_pool {
         GGML_OP_POOL_MAX,
@@ -1293,6 +1367,13 @@ extern "C" {
             int                   p0,
             int                   p1);
 
+    // nearest interpolate
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1346,6 +1427,27 @@ extern "C" {
         struct ggml_tensor  * a,
         enum ggml_unary_op op);
 
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   qh,
+            int                   kh);
+
+    // used in sam
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1500,6 +1602,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * tensor);
 
+
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
index 3d13e852a4d24e1db7c01f6051e7cbd56247dabf..e44c3bd03fa93f02bd6a5ec02b2d124ee5a7c22c 100755 (executable)
@@ -1,14 +1,16 @@
 #!/bin/bash
 
-cp -rpv ../ggml/src/ggml.c           ./ggml.c
-cp -rpv ../ggml/src/ggml-cuda.h      ./ggml-cuda.h
-cp -rpv ../ggml/src/ggml-cuda.cu     ./ggml-cuda.cu
-cp -rpv ../ggml/src/ggml-opencl.h    ./ggml-opencl.h
-cp -rpv ../ggml/src/ggml-opencl.cpp  ./ggml-opencl.cpp
-cp -rpv ../ggml/src/ggml-metal.h     ./ggml-metal.h
-cp -rpv ../ggml/src/ggml-metal.m     ./ggml-metal.m
-cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
-cp -rpv ../ggml/include/ggml/ggml.h  ./ggml.h
+cp -rpv ../ggml/src/ggml.c                ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c          ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-cuda.h           ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu          ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h         ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.cpp       ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h          ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m          ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal      ./ggml-metal.metal
+cp -rpv ../ggml/include/ggml/ggml.h       ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
 
 cp -rpv ../ggml/tests/test-opt.cpp    ./tests/test-opt.cpp
 cp -rpv ../ggml/tests/test-grad0.cpp  ./tests/test-grad0.cpp