return res;
}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_OPT_STEP_SGD);
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+ if (res) {
+ return res;
+ }
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+ return res;
+}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
+ case GGML_OP_OPT_STEP_SGD:
+ {
+ n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
+ } break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
return 1;
}
+
+int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+ const int64_t np = ggml_nelements(op->src[0]);
+ ggml_metal_kargs_opt_step_sgd args = {
+ /*.np =*/ np,
+ };
+
+ int ida = 0;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+ const int64_t n = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+ return 1;
+}
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
+
+kernel void kernel_opt_step_sgd_f32(
+ constant ggml_metal_kargs_opt_step_sgd & args,
+ device float * x,
+ device const float * g,
+ device const float * pars,
+ uint gid[[thread_position_in_grid]]) {
+
+ if (gid >= args.np) {
+ return;
+ }
+
+ x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
+}