]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal: add support for opt_step_sgd (llama/16539)
authorSam/Samuel <redacted>
Mon, 13 Oct 2025 08:25:02 +0000 (16:25 +0800)
committerGeorgi Gerganov <redacted>
Tue, 14 Oct 2025 19:07:44 +0000 (22:07 +0300)
* metal: add support for opt_step_sgd

* add newline to pass EditorConfig check

src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.metal

index 335d5848e290c48d90ef4c282b049433a6cbe54f..866cd2da585761d14bef55bbac68e2348efe57e9 100644 (file)
@@ -1519,3 +1519,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_
 
     return res;
 }
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_OPT_STEP_SGD);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
index 283e70fa7910971c88998ff5c4f9255357845d83..28ae2e1765146f7e450b75b27a2763f88abec2da 100644 (file)
@@ -136,6 +136,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
index e38e70768040a1b8b212f0b2c04aa2a04f3a099a..fc5083043f7c97d2b8a14c06ef46f8754c6ad6bf 100644 (file)
@@ -800,6 +800,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return has_simdgroup_reduction;
         default:
             return false;
index c4c9f0a7f6aefc64fb86790642f5078ce499cb5f..a448c14f66b63616c960c712c6a4fb4f26b6ca07 100644 (file)
@@ -781,4 +781,8 @@ typedef struct {
     int64_t  np;
 } ggml_metal_kargs_opt_step_adamw;
 
+typedef struct {
+    int64_t  np;
+} ggml_metal_kargs_opt_step_sgd;
+
 #endif // GGML_METAL_IMPL
index c01c0b181e8f56ca3961077660babb1934602c62..a61ea8fb5a7b37202daa13d78b1c0616f451a78a 100644 (file)
@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
             } break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
+            } break;
        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
 
     return 1;
 }
+
+int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+    const int64_t np = ggml_nelements(op->src[0]);
+    ggml_metal_kargs_opt_step_sgd args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
index 6641cf5dfcb52ad65fd98ea7839f2c562fcf80ca..f352738698beb9b00dcf09ea4080be41ffad569c 100644 (file)
@@ -80,6 +80,7 @@ int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
 
 #ifdef __cplusplus
 }
index 780d6a97350ebd5fc5d0834fc06b686e44b7c144..74a9aa998837cdc84213632d045b85b7c32fee71 100644 (file)
@@ -8806,3 +8806,17 @@ kernel void kernel_opt_step_adamw_f32(
 
     x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
 }
+
+kernel void kernel_opt_step_sgd_f32(
+        constant    ggml_metal_kargs_opt_step_sgd & args,
+        device       float * x,
+        device const float * g,
+        device const float * pars,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
+}