]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add mrope kernel for metal (llama/13457)
authorXuan-Son Nguyen <redacted>
Mon, 12 May 2025 08:29:13 +0000 (10:29 +0200)
committerGeorgi Gerganov <redacted>
Tue, 13 May 2025 10:02:19 +0000 (13:02 +0300)
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal

index ea8cf45863af97becf577fdc6e76e8d0de53c2f7..17eab976f3ad114eaeb9ce108ecd635ce12fddc9 100644 (file)
@@ -207,6 +207,10 @@ typedef struct {
     float    attn_factor;
     float    beta_fast;
     float    beta_slow;
+    int32_t  sect_0;
+    int32_t  sect_1;
+    int32_t  sect_2;
+    int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
 typedef struct {
index 943f0fe722a2f5d78a27c2f1c3439be900fbafac..576f9581bdaee40c4bcb7c1677ffdfaaf6b1eb00 100644 (file)
@@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,            mul_mm_id_iq4_xs_f16,            has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                   rope_norm_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                   rope_norm_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,                  rope_multi_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,                  rope_multi_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,                 rope_vision_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,                 rope_vision_f16,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                   rope_neox_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                   rope_neox_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                      im2col_f16,                      true);
@@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
+            return true;
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
@@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
+
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);
@@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
                 memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
 
-                const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+                const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+                // mrope
+                const int sect_0 = ((const int32_t *) dst->op_params)[11];
+                const int sect_1 = ((const int32_t *) dst->op_params)[12];
+                const int sect_2 = ((const int32_t *) dst->op_params)[13];
+                const int sect_3 = ((const int32_t *) dst->op_params)[14];
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (!is_neox) {
+                if (is_neox) {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_mrope && !is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 } else {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 }
@@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
                     /*.attn_factor =*/ attn_factor,
                     /*.beta_fast   =*/ beta_fast,
                     /*.beta_slow   =*/ beta_slow,
+                    /* sect_0      =*/ sect_0,
+                    /* sect_1      =*/ sect_1,
+                    /* sect_2      =*/ sect_2,
+                    /* sect_3      =*/ sect_3,
                 };
 
                 [encoder setComputePipelineState:pipeline];
index a808e79c2275224c6c3385947cd398f27f1a11ee..9cfddf4503abe704a977b9350b6937ac99ccba67 100644 (file)
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
     }
 }
 
+template<typename T>
+kernel void kernel_rope_multi(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            // mrope theta calculations
+            // note: the rest is the same as kernel_rope_neox
+            const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
+            const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1
+            const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
+            const int sector    = ic % sect_dims;
+
+            float theta_base;
+            if (sector < args.sect_0) {
+                theta_base = (float) pos[i2];
+            } else if (sector < sec_w01) {
+                theta_base = (float) pos[i2 + args.ne02];
+            } else if (sector < sec_w012) {
+                theta_base = (float) pos[i2 + args.ne02 * 2];
+            } else {
+                theta_base = (float) pos[i2 + args.ne02 * 3];
+            }
+            // end of mrope
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_vision(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
+            const int ic = i0/2;
+
+            // mrope theta calculations (only support 2 dimensions)
+            const int sect_dims = args.sect_0 + args.sect_1;
+            const int sector    = ic % sect_dims;
+
+            float p;
+            float theta_base;
+            if (sector < args.sect_1) {
+                p = (float) sector;
+                theta_base = (float) pos[i2];
+            } else {
+                p = (float) sector - args.sect_0;
+                theta_base = (float) pos[i2 + args.ne02];
+            }
+
+            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
+            // end of mrope
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims]; // different from kernel_rope_multi
+
+            dst_data[0]           = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
 typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
 typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
+typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
 
 template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
 template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
 template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
 template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
+template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
+template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
+
+template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
+template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,