]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: move op parameters from tensors to ggml_tensor::op_params (#2333)
authorslaren <redacted>
Sun, 23 Jul 2023 12:36:02 +0000 (14:36 +0200)
committerGitHub <redacted>
Sun, 23 Jul 2023 12:36:02 +0000 (14:36 +0200)
* ggml: move op parameters from tensors to ggml_tensor::op_params

* alibi: use memcpy for float params

* remove `src[1] = NULL` in ops

ggml-cuda.cu
ggml-metal.m
ggml.c
ggml.h

index 720447440183083e0877039ace008beaf03e6eff..6fb55d838dfb3e4a7ec505f17a3adf334e085ffe 100644 (file)
@@ -2742,6 +2742,7 @@ inline void ggml_cuda_op_mul(
     (void) dst;
     (void) src0_ddq_i;
     (void) i02;
+    (void) i1;
 }
 
 inline void ggml_cuda_op_gelu(
@@ -3037,15 +3038,15 @@ inline void ggml_cuda_op_rope(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
-
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
     // RoPE alteration for extended context
+
     float freq_base, freq_scale;
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
@@ -3061,6 +3062,7 @@ inline void ggml_cuda_op_rope(
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
     }
 
+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
@@ -3079,11 +3081,12 @@ inline void ggml_cuda_op_diag_mask_inf(
     const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_past = ((int32_t *) dst->op_params)[0];
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
 
+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
@@ -3803,7 +3806,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
         if (tensor->op == GGML_OP_VIEW) {
-            memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
+            memcpy(&offset, tensor->op_params, sizeof(size_t));
         }
         extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
index 78a3b65f1959265a617a69f8a0e90f9e626fa1fa..bf3f68fe45726ca93a10b9293200d11c717e7789 100644 (file)
@@ -585,7 +585,7 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *)(dst->op_params))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -850,9 +850,10 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                            const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                            const int   n_head   = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *)   src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            const int n_head = ((int32_t *) dst->op_params)[1];
+                            float max_bias;
+                            memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -890,15 +891,14 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
-
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *) dst->op_params)[0];
+                            const int n_dims = ((int32_t *) dst->op_params)[1];
+                            const int mode   = ((int32_t *) dst->op_params)[2];
 
                             float freq_base;
                             float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
diff --git a/ggml.c b/ggml.c
index 6055da867cb27ed1e3df51c0bae84e04571c28ab..747a39241b0b6fae1bbca6b84358f529335178f1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4590,6 +4590,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
+        /*.op_params    =*/ {0},
         /*.is_param     =*/ false,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -4969,6 +4970,11 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char *
     return tensor;
 }
 
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5019,7 +5025,6 @@ struct ggml_tensor * ggml_dup_impl(
     result->op   = GGML_OP_DUP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5143,23 +5148,13 @@ struct ggml_tensor * ggml_acc_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
-
-    ((int32_t *) c->data)[0] = nb1;
-    ((int32_t *) c->data)[1] = nb2;
-    ((int32_t *) c->data)[2] = nb3;
-    ((int32_t *) c->data)[3] = offset;
-    ((int32_t *) c->data)[4] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ACC;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -5332,7 +5327,6 @@ struct ggml_tensor * ggml_sqr_impl(
     result->op   = GGML_OP_SQR;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5366,7 +5360,6 @@ struct ggml_tensor * ggml_sqrt_impl(
     result->op   = GGML_OP_SQRT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5401,7 +5394,6 @@ struct ggml_tensor * ggml_log_impl(
     result->op   = GGML_OP_LOG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5434,7 +5426,6 @@ struct ggml_tensor * ggml_sum(
     result->op   = GGML_OP_SUM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5461,7 +5452,6 @@ struct ggml_tensor * ggml_sum_rows(
     result->op   = GGML_OP_SUM_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5484,7 +5474,6 @@ struct ggml_tensor * ggml_mean(
     result->op   = GGML_OP_MEAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5508,7 +5497,6 @@ struct ggml_tensor * ggml_argmax(
     result->op   = GGML_OP_ARGMAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5586,7 +5574,6 @@ struct ggml_tensor * ggml_abs_impl(
     result->op   = GGML_OP_ABS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5621,7 +5608,6 @@ struct ggml_tensor * ggml_sgn_impl(
     result->op   = GGML_OP_SGN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5655,7 +5641,6 @@ struct ggml_tensor * ggml_neg_impl(
     result->op   = GGML_OP_NEG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5689,7 +5674,6 @@ struct ggml_tensor * ggml_step_impl(
     result->op   = GGML_OP_STEP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5723,7 +5707,6 @@ struct ggml_tensor * ggml_tanh_impl(
     result->op   = GGML_OP_TANH;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5757,7 +5740,6 @@ struct ggml_tensor * ggml_elu_impl(
     result->op   = GGML_OP_ELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5791,7 +5773,6 @@ struct ggml_tensor * ggml_relu_impl(
     result->op   = GGML_OP_RELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5825,7 +5806,6 @@ struct ggml_tensor * ggml_gelu_impl(
     result->op   = GGML_OP_GELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5859,7 +5839,6 @@ struct ggml_tensor * ggml_gelu_quick_impl(
     result->op   = GGML_OP_GELU_QUICK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5893,7 +5872,6 @@ struct ggml_tensor * ggml_silu_impl(
     result->op   = GGML_OP_SILU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5948,10 +5926,11 @@ struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    // TODO: maybe store epsilon here?
+
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -5980,10 +5959,11 @@ struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    // TODO: maybe store epsilon here?
+
     result->op   = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -6136,23 +6116,13 @@ struct ggml_tensor * ggml_set_impl(
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
-
-    (( int32_t * ) c->data)[0] = nb1;
-    (( int32_t * ) c->data)[1] = nb2;
-    (( int32_t * ) c->data)[2] = nb3;
-    (( int32_t * ) c->data)[3] = offset;
-    (( int32_t * ) c->data)[4] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_SET;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -6277,7 +6247,6 @@ struct ggml_tensor * ggml_cont_impl(
     result->op   = GGML_OP_CONT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6321,7 +6290,6 @@ struct ggml_tensor * ggml_reshape(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6346,7 +6314,6 @@ struct ggml_tensor * ggml_reshape_1d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6372,7 +6339,6 @@ struct ggml_tensor * ggml_reshape_2d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6399,7 +6365,6 @@ struct ggml_tensor * ggml_reshape_3d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6428,7 +6393,6 @@ struct ggml_tensor * ggml_reshape_4d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6450,19 +6414,11 @@ struct ggml_tensor * ggml_view_1d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6488,13 +6444,7 @@ struct ggml_tensor * ggml_view_2d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
@@ -6503,8 +6453,6 @@ struct ggml_tensor * ggml_view_2d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6532,13 +6480,7 @@ struct ggml_tensor * ggml_view_3d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6547,8 +6489,6 @@ struct ggml_tensor * ggml_view_3d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6578,13 +6518,7 @@ struct ggml_tensor * ggml_view_4d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6593,8 +6527,6 @@ struct ggml_tensor * ggml_view_4d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6655,22 +6587,9 @@ struct ggml_tensor * ggml_permute(
     result->op   = GGML_OP_PERMUTE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-
-    if (is_node) {
-        ggml_scratch_save(ctx);
 
-        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
-
-        ((int32_t *) b->data)[0] = axis0;
-        ((int32_t *) b->data)[1] = axis1;
-        ((int32_t *) b->data)[2] = axis2;
-        ((int32_t *) b->data)[3] = axis3;
-
-        ggml_scratch_load(ctx);
-
-        result->src[2] = b;
-    }
+    int32_t params[] = { axis0, axis1, axis2, axis3 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     return result;
 }
@@ -6698,7 +6617,6 @@ struct ggml_tensor * ggml_transpose(
     result->op   = GGML_OP_TRANSPOSE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6776,7 +6694,6 @@ struct ggml_tensor * ggml_diag(
     result->op   = GGML_OP_DIAG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6797,19 +6714,12 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6844,20 +6754,12 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(b, "n_past, inplace");
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6893,7 +6795,6 @@ struct ggml_tensor * ggml_soft_max_impl(
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6969,23 +6870,14 @@ struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_dims;
-    ((int32_t *) b->data)[2] = mode;
-    ((int32_t *) b->data)[3] = n_ctx;
-    memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
-    memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
-
-    ggml_scratch_load(ctx);
+    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    memcpy(params + 4, &freq_base, sizeof(float));
+    memcpy(params + 5, &freq_scale, sizeof(float));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7042,22 +6934,12 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
-    ggml_set_name(b, "n_past, n_dims, mode");
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_dims;
-    ((int32_t *) b->data)[2] = mode;
-    ((int32_t *) b->data)[3] = n_ctx;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, n_dims, mode, n_ctx };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7082,21 +6964,13 @@ struct ggml_tensor * ggml_alibi(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_head;
-    GGML_ASSERT(sizeof(float) == sizeof(int32_t));
-    (((float *) b->data)[2]) = bias_max;
-
-    ggml_scratch_load(ctx);
+    int32_t op_params[3] = { n_past, n_head };
+    memcpy(op_params + 2, &bias_max, sizeof(float));
+    ggml_set_op_params(result, &op_params, sizeof(op_params));
 
     result->op   = GGML_OP_ALIBI;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7118,19 +6992,12 @@ struct ggml_tensor * ggml_clamp(
     // TODO: when implement backward, fix this:
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
-
-    ((float *) b->data)[0] = min;
-    ((float *) b->data)[1] = max;
-
-    ggml_scratch_load(ctx);
+    float params[] = { min, max };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_CLAMP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7163,18 +7030,13 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-    ((int32_t*)c->data)[0] = s0;
-    ((int32_t*)c->data)[1] = p0;
-    ((int32_t*)c->data)[2] = d0;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { s0, p0, d0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_CONV_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -7207,21 +7069,13 @@ struct ggml_tensor* ggml_conv_2d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
-    ((int32_t*)c->data)[0] = s0;
-    ((int32_t*)c->data)[1] = s1;
-    ((int32_t*)c->data)[2] = p0;
-    ((int32_t*)c->data)[3] = p1;
-    ((int32_t*)c->data)[4] = d0;
-    ((int32_t*)c->data)[5] = d1;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_CONV_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 
@@ -7245,7 +7099,7 @@ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
     return (ins + 2 * p - ks) / s + 1;
 }
 
-// ggml_pool_2d
+// ggml_pool_1d
 
 struct ggml_tensor* ggml_pool_1d(
         struct ggml_context * ctx,
@@ -7268,18 +7122,12 @@ struct ggml_tensor* ggml_pool_1d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
-    ((int32_t*)c->data)[0] = op;
-    ((int32_t*)c->data)[1] = k0;
-    ((int32_t*)c->data)[2] = s0;
-    ((int32_t*)c->data)[3] = p0;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { op, k0, s0, p0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_POOL_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = c;
 
     return result;
 }
@@ -7311,21 +7159,12 @@ struct ggml_tensor* ggml_pool_2d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
-    ((int32_t*)c->data)[0] = op;
-    ((int32_t*)c->data)[1] = k0;
-    ((int32_t*)c->data)[2] = k1;
-    ((int32_t*)c->data)[3] = s0;
-    ((int32_t*)c->data)[4] = s1;
-    ((int32_t*)c->data)[5] = p0;
-    ((int32_t*)c->data)[6] = p1;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_POOL_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = c;
 
     return result;
 }
@@ -7484,21 +7323,12 @@ struct ggml_tensor * ggml_win_part(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-
-    ((int32_t *) b->data)[0] = npx;
-    ((int32_t *) b->data)[1] = npy;
-    ((int32_t *) b->data)[2] = w;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { npx, npy, w };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_WIN_PART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = b;
 
     return result;
 }
@@ -7523,19 +7353,12 @@ struct ggml_tensor * ggml_win_unpart(
     const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-
-    ((int32_t *) b->data)[0] = w;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { w };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_WIN_UNPART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = b;
 
     return result;
 }
@@ -7553,19 +7376,13 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_UNARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7600,20 +7417,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_BINARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7647,19 +7458,13 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM1;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7692,20 +7497,14 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM2;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7741,21 +7540,15 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM3;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
-    result->src[3] = c;
+    result->src[2] = c;
 
     return result;
 }
@@ -8983,21 +8776,17 @@ static void ggml_compute_forward_acc_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
-    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(opt0) == 5);
-
     // view src0 and dst with these strides and data offset inbytes during acc
     // nb0 is implicitely element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) opt0->data)[0];
-    size_t nb2     = ((int32_t *) opt0->data)[1];
-    size_t nb3     = ((int32_t *) opt0->data)[2];
-    size_t offset  = ((int32_t *) opt0->data)[3];
-    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -9066,13 +8855,12 @@ static void ggml_compute_forward_acc(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_acc_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -11092,21 +10880,17 @@ static void ggml_compute_forward_set_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
-    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(opt0) == 5);
-
     // view src0 and dst with these strides and data offset inbytes during set
     // nb0 is implicitely element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) opt0->data)[0];
-    size_t nb2     = ((int32_t *) opt0->data)[1];
-    size_t nb3     = ((int32_t *) opt0->data)[2];
-    size_t offset  = ((int32_t *) opt0->data)[3];
-    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -11166,13 +10950,12 @@ static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_set_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_set_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -11568,17 +11351,14 @@ static void ggml_compute_forward_diag(
 static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int  n_past  =       ((int32_t *) src1->data)[0];
-    const bool inplace = (bool)((int32_t *) src1->data)[1];
+    const int  n_past  =       ((int32_t *) dst->op_params)[0];
+    const bool inplace = (bool)((int32_t *) dst->op_params)[1];
 
     GGML_ASSERT(n_past >= 0);
 
@@ -11621,12 +11401,11 @@ static void ggml_compute_forward_diag_mask_f32(
 static void ggml_compute_forward_diag_mask_inf(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY);
+                ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY);
             } break;
         default:
             {
@@ -11638,12 +11417,11 @@ static void ggml_compute_forward_diag_mask_inf(
 static void ggml_compute_forward_diag_mask_zero(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0);
+                ggml_compute_forward_diag_mask_f32(params, src0, dst, 0);
             } break;
         default:
             {
@@ -11841,20 +11619,17 @@ static void ggml_compute_forward_soft_max_back(
 static void ggml_compute_forward_alibi_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -11907,20 +11682,17 @@ static void ggml_compute_forward_alibi_f32(
 static void ggml_compute_forward_alibi_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -11973,16 +11745,15 @@ static void ggml_compute_forward_alibi_f16(
 static void ggml_compute_forward_alibi(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+                ggml_compute_forward_alibi_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
+                ggml_compute_forward_alibi_f32(params, src0, dst);
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -12012,19 +11783,17 @@ static void ggml_compute_forward_alibi(
 static void ggml_compute_forward_clamp_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_nelements(src1) == 2);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const float min = ((float *) src1->data)[0];
-    const float max = ((float *) src1->data)[1];
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -12054,12 +11823,11 @@ static void ggml_compute_forward_clamp_f32(
 static void ggml_compute_forward_clamp(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_clamp_f32(params, src0, src1, dst);
+                ggml_compute_forward_clamp_f32(params, src0, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -12089,10 +11857,7 @@ static void ggml_compute_forward_clamp(
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12101,12 +11866,12 @@ static void ggml_compute_forward_rope_f32(
     float freq_base;
     float freq_scale;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12221,10 +11986,7 @@ static void ggml_compute_forward_rope_f32(
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12233,12 +11995,12 @@ static void ggml_compute_forward_rope_f16(
     float freq_base;
     float freq_scale;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12353,16 +12115,15 @@ static void ggml_compute_forward_rope_f16(
 static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12376,10 +12137,7 @@ static void ggml_compute_forward_rope(
 static void ggml_compute_forward_rope_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 4);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12389,9 +12147,9 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
 
     assert(n_past >= 0);
 
@@ -12475,10 +12233,7 @@ static void ggml_compute_forward_rope_back_f32(
 static void ggml_compute_forward_rope_back_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12488,9 +12243,9 @@ static void ggml_compute_forward_rope_back_f16(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
 
     assert(n_past >= 0);
 
@@ -12574,16 +12329,15 @@ static void ggml_compute_forward_rope_back_f16(
 static void ggml_compute_forward_rope_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_back_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_back_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12780,7 +12534,7 @@ static void ggml_compute_forward_conv_1d_s1_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -12983,7 +12737,7 @@ static void ggml_compute_forward_conv_1d_s2_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -13003,14 +12757,13 @@ static void ggml_compute_forward_conv_1d_s2_ph(
 // ggml_compute_forward_conv_1d
 
 static void ggml_compute_forward_conv_1d(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * src0,
-    const struct ggml_tensor * src1,
-    const struct ggml_tensor * opt0,
-    struct ggml_tensor * dst) {
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[1];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[2];
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
     GGML_ASSERT(d0 == 1); // dilation not supported
     GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
     if (s0 == 1) {
@@ -13028,7 +12781,6 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13048,12 +12800,12 @@ static void ggml_compute_forward_conv_2d_f16_f32(
     // size of the convolution row - the kernel size unrolled across all channels
     const int ew0 = nk0*nk1*ne02;
 
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
-    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
-    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -13125,17 +12877,15 @@ static void ggml_compute_forward_conv_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
-        struct ggml_tensor * dst
-        ) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
+                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
                 GGML_ASSERT(false);
             } break;
         default:
@@ -13200,12 +12950,11 @@ static void ggml_compute_forward_pool_1d_sk_p0(
 // ggml_compute_forward_pool_1d
 
 static void ggml_compute_forward_pool_1d(
-    const struct ggml_compute_params* params,
-    const struct ggml_tensor* src0,
-    const struct ggml_tensor* opt0,
-    struct ggml_tensor* dst) {
-    GGML_ASSERT(opt0->ne[0] == 4);
-    const int* opts = (const int*)opt0->data;
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+              struct ggml_tensor * dst) {
+
+    const int32_t* opts = (const int32_t*)dst->op_params;
     enum ggml_op_pool op = opts[0];
     const int k0 = opts[1];
     const int s0 = opts[2];
@@ -13219,12 +12968,12 @@ static void ggml_compute_forward_pool_1d(
 // ggml_compute_forward_pool_2d_sk_p0
 
 static void ggml_compute_forward_pool_2d_sk_p0(
-    const struct ggml_compute_params * params,
-    const enum ggml_op_pool op,
-    const struct ggml_tensor * src,
-    const int k0,
-    const int k1,
-    struct ggml_tensor * dst) {
+        const struct ggml_compute_params * params,
+        const enum   ggml_op_pool op,
+        const struct ggml_tensor * src,
+        const int k0,
+        const int k1,
+        struct ggml_tensor * dst) {
     assert(src->type == GGML_TYPE_F32);
     assert(params->ith == 0);
 
@@ -13284,12 +13033,11 @@ static void ggml_compute_forward_pool_2d_sk_p0(
 // ggml_compute_forward_pool_2d
 
 static void ggml_compute_forward_pool_2d(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * src0,
-    const struct ggml_tensor * opt0,
-    struct ggml_tensor * dst) {
-    GGML_ASSERT(opt0->ne[0] == 7);
-    const int* opts = (const int*)opt0->data;
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+              struct ggml_tensor * dst) {
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
     enum ggml_op_pool op = opts[0];
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -13314,7 +13062,7 @@ static void ggml_compute_forward_flash_attn_f32(
         const struct ggml_tensor * k,
         const struct ggml_tensor * v,
         const bool masked,
-             struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -13492,7 +13240,7 @@ static void ggml_compute_forward_flash_attn_f16(
         const struct ggml_tensor * k,
         const struct ggml_tensor * v,
         const bool masked,
-             struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -14257,7 +14005,6 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_win_part_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -14266,9 +14013,9 @@ static void ggml_compute_forward_win_part_f32(
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
-    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
-    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
-    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
 
     assert(ne00 == ne0);
     assert(ne3  == nep0*nep1);
@@ -14302,12 +14049,11 @@ static void ggml_compute_forward_win_part_f32(
 static void ggml_compute_forward_win_part(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+                ggml_compute_forward_win_part_f32(params, src0, dst);
             } break;
         default:
             {
@@ -14321,7 +14067,6 @@ static void ggml_compute_forward_win_part(
 static void ggml_compute_forward_win_unpart_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -14330,7 +14075,7 @@ static void ggml_compute_forward_win_unpart_f32(
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
-    const int32_t w = ((const int32_t *)(opt0->data))[0];
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
 
     // padding
     const int px = (w - ne1%w)%w;
@@ -14364,12 +14109,11 @@ static void ggml_compute_forward_win_unpart_f32(
 static void ggml_compute_forward_win_unpart(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+                ggml_compute_forward_win_unpart_f32(params, src0, dst);
             } break;
         default:
             {
@@ -14888,7 +14632,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_ACC:
             {
-                ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SUB:
             {
@@ -15008,7 +14752,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SET:
             {
-                ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CPY:
             {
@@ -15048,11 +14792,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
-                ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
             {
-                ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_SOFT_MAX:
             {
@@ -15064,35 +14808,35 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_ROPE:
             {
-                ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_rope(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_ROPE_BACK:
             {
-                ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_ALIBI:
             {
-                ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_alibi(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_CLAMP:
             {
-                ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_CONV_1D:
             {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_2D:
             {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_POOL_1D:
             {
-                ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_pool_1d(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_POOL_2D:
             {
-                ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_pool_2d(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
@@ -15114,40 +14858,45 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_WIN_PART:
             {
-                ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
+                ggml_compute_forward_win_part(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_WIN_UNPART:
             {
-                ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
+                ggml_compute_forward_win_unpart(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
-                const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
+                ggml_unary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
             }
             break;
         case GGML_OP_MAP_BINARY:
             {
-                const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
+                ggml_binary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM1:
             {
-                const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
+                ggml_custom1_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM2:
             {
-                const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
+                ggml_custom2_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM3:
             {
-                const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
-                ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
+                ggml_custom3_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun);
             }
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
@@ -15211,12 +14960,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
                 }
                 if (src1->grad) {
-                    GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
-                    GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
-                    const size_t nb1     = (( int32_t * ) tensor->src[2]->data)[0];
-                    const size_t nb2     = (( int32_t * ) tensor->src[2]->data)[1];
-                    const size_t nb3     = (( int32_t * ) tensor->src[2]->data)[2];
-                    const size_t offset  = (( int32_t * ) tensor->src[2]->data)[3];
+                    const size_t nb1     = ((int32_t *) tensor->op_params)[0];
+                    const size_t nb2     = ((int32_t *) tensor->op_params)[1];
+                    const size_t nb3     = ((int32_t *) tensor->op_params)[2];
+                    const size_t offset  = ((int32_t *) tensor->op_params)[3];
 
                     struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
                         tensor->grad,
@@ -15524,12 +15271,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_SET:
             {
-                GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
-                GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
-                const size_t nb1     = (( int32_t * ) tensor->src[2]->data)[0];
-                const size_t nb2     = (( int32_t * ) tensor->src[2]->data)[1];
-                const size_t nb3     = (( int32_t * ) tensor->src[2]->data)[2];
-                const size_t offset  = (( int32_t * ) tensor->src[2]->data)[3];
+                const size_t nb1     = ((int32_t *) tensor->op_params)[0];
+                const size_t nb2     = ((int32_t *) tensor->op_params)[1];
+                const size_t nb3     = ((int32_t *) tensor->op_params)[2];
+                const size_t offset  = ((int32_t *) tensor->op_params)[3];
 
                 struct ggml_tensor * tensor_grad_view = NULL;
 
@@ -15606,8 +15351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     size_t offset;
 
-                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
-                    memcpy(&offset, tensor->src[2]->data, sizeof(offset));
+                    memcpy(&offset, tensor->op_params, sizeof(offset));
 
                     size_t nb1     = tensor->nb[1];
                     size_t nb2     = tensor->nb[2];
@@ -15634,7 +15378,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int32_t * axes = (int32_t *) tensor->src[2]->data;
+                    int32_t * axes = (int32_t *) tensor->op_params;
                     int axis0 = axes[0] & 0x3;
                     int axis1 = axes[1] & 0x3;
                     int axis2 = axes[2] & 0x3;
@@ -15690,9 +15434,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 2);
-                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
@@ -15706,9 +15448,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 2);
-                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
@@ -15737,12 +15477,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 6);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
-                    const int n_ctx  = ((int32_t *) src1->data)[3];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode   = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
@@ -15760,12 +15498,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_ROPE_BACK:
             {
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 4);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
-                    const int n_ctx  = ((int32_t *) src1->data)[3];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode   = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope(ctx,
@@ -16543,9 +16279,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_GET_ROWS_BACK:
             case GGML_OP_DIAG:
             case GGML_OP_DIAG_MASK_ZERO:
-                {
-                    n_tasks = 1;
-                } break;
             case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
             case GGML_OP_SOFT_MAX_BACK:
@@ -17289,7 +17022,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                             tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
 
                             uint64_t offs;
-                            memcpy(&offs, args[2]->data, sizeof(offs));
+                            memcpy(&offs, tensor->op_params, sizeof(offs));
 
                             tensor->data = ((char *) tensor->data) + offs;
                         } break;
diff --git a/ggml.h b/ggml.h
index 5023b165287888e1938e5498edc25694ff7214e1..871c85a89aae796f841fce9bd9bee0284929f2c6 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          48
+#define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 
 
@@ -418,6 +419,9 @@ extern "C" {
         // compute data
         enum ggml_op op;
 
+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
+
         bool is_param;
 
         struct ggml_tensor * grad;