]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : add opt_step_adamw and op_sum (llama/16529)
authorSam/Samuel <redacted>
Sun, 12 Oct 2025 18:43:14 +0000 (02:43 +0800)
committerGeorgi Gerganov <redacted>
Tue, 14 Oct 2025 19:07:44 +0000 (22:07 +0300)
* scaffold to support opt step adamw on metal (not written so far)

* add opt-step-adamw kernel for metal

* pass op->src[4] as a separate buffer to the pipeline

* add bounds check to opt-step-adamw kernel

* complete scaffold for GGML_OP_SUM

* naive GGML_OP_SUM kernel

* remove unwanted comment

* change OP_SUM capability gate

* Add has_simdgroup_reduction to both ops to pass CI

src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.metal

index e23abdda97405bc53af5d789f1b9d0b451af8ceb..335d5848e290c48d90ef4c282b049433a6cbe54f 100644 (file)
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_SUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
 
@@ -1482,3 +1501,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_OPT_STEP_ADAMW);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
index 1034e4bbf65960af5c157f75a8e031670044a898..283e70fa7910971c88998ff5c4f9255357845d83 100644 (file)
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows          (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -134,6 +135,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad               (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
index 95279730152455a9a2839fbc9b83dea681281194..e38e70768040a1b8b212f0b2c04aa2a04f3a099a 100644 (file)
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_COS:
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
@@ -798,6 +799,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                         return false;
                 };
             }
+        case GGML_OP_OPT_STEP_ADAMW:
+            return has_simdgroup_reduction;
         default:
             return false;
     }
index c9dff873058697d8fdbd75c6e16191821bffe475..c4c9f0a7f6aefc64fb86790642f5078ce499cb5f 100644 (file)
@@ -544,6 +544,10 @@ typedef struct{
     float    limit;
 } ggml_metal_kargs_glu;
 
+typedef struct {
+    uint64_t np;
+} ggml_metal_kargs_sum;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
@@ -773,4 +777,8 @@ typedef struct {
     uint64_t nb01;
 } ggml_metal_kargs_argmax;
 
+typedef struct {
+    int64_t  np;
+} ggml_metal_kargs_opt_step_adamw;
+
 #endif // GGML_METAL_IMPL
index 5f9370449bb2deb48e7864c6f583152f4d77c0c4..c01c0b181e8f56ca3961077660babb1934602c62 100644 (file)
@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_glu(ctx, idx);
             } break;
+        case GGML_OP_SUM:
+            {
+                n_fuse = ggml_metal_op_sum(ctx, idx);
+            } break;
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
             {
@@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_argmax(ctx, idx);
             } break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
+            } break;
        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op  = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
+
+    ggml_metal_kargs_sum args = {
+        /*.np =*/ n,
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
 
     return 1;
 }
+
+int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+
+    const int64_t np = ggml_nelements(op->src[0]);
+    ggml_metal_kargs_opt_step_adamw args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
index d4cb9446212d90c5f89191878669dff0ce03a6a2..6641cf5dfcb52ad65fd98ea7839f2c562fcf80ca 100644 (file)
@@ -50,6 +50,7 @@ int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
@@ -78,6 +79,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 
 #ifdef __cplusplus
 }
index ddc285042d284f081dbb357d719297f3419641b1..780d6a97350ebd5fc5d0834fc06b686e44b7c144 100644 (file)
@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
     }
 }
 
+kernel void kernel_op_sum_f32(
+        constant ggml_metal_kargs_sum & args,
+        device const float * src0,
+        device       float * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]]) {
+
+    if (tiitg != 0) {
+        return;
+    }
+
+    float acc = 0.0f;
+    for (ulong i = 0; i < args.np; ++i) {
+        acc += src0[i];
+    }
+
+    dst[0] = acc;
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
@@ -8754,3 +8772,37 @@ kernel void kernel_pool_2d_avg_f32(
 
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
+
+kernel void kernel_opt_step_adamw_f32(
+        constant    ggml_metal_kargs_opt_step_adamw & args,
+        device       float * x,
+        device const float * g,
+        device       float * g_m,
+        device       float * g_v,
+        device const float * pars,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const float alpha  = pars[0];
+    const float beta1  = pars[1];
+    const float beta2  = pars[2];
+    const float eps    = pars[3];
+    const float wd     = pars[4];
+    const float beta1h = pars[5];
+    const float beta2h = pars[6];
+
+    const float gi = g[gid];
+    const float gmi = g_m[gid] * beta1 +      gi * (1.0f - beta1);
+    const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
+
+    g_m[gid] = gmi;
+    g_v[gid] = gvi;
+
+    const float mh =      gmi * beta1h;
+    const float vh = sqrt(gvi * beta2h) + eps;
+
+    x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
+}