]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add im2col F32 dst support (llama/5132)
authorGeorgi Gerganov <redacted>
Wed, 31 Jan 2024 13:35:41 +0000 (15:35 +0200)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:46 +0000 (09:55 +0200)
ggml-metal.m
ggml-metal.metal

index f87859552816b0f90668dd696aebf666169a937f..5260ed827702694d22c569274812c8377475cbf8 100644 (file)
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+            return true;
+        case GGML_OP_POOL_1D:
+        case GGML_OP_POOL_2D:
+            return false;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(src0->type == GGML_TYPE_F16);
                         GGML_ASSERT(src1->type == GGML_TYPE_F32);
-                        GGML_ASSERT( dst->type == GGML_TYPE_F16);
+                        GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
                         const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
                         const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
                         const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
                         const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
                         const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
                         const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
                         const int32_t N  = src1->ne[is_2D ? 3 : 2];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
 
                         id<MTLComputePipelineState> pipeline = nil;
 
-                        switch (src0->type) {
-                            case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                        switch (dst->type) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
                             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
                             default: GGML_ASSERT(false);
                         };
index 2614d82e8b9b4d4a246fc6b0963f10c52c4b3ea4..efed6ad465e78d4fba26c293214ac980b131e7b3 100644 (file)
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
-kernel void kernel_im2col_f16(
+typedef void (im2col_t)(
         device const float * x,
-        device       half * dst,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+        device const float * x,
+        device        char * dst,
         constant   int32_t & ofs0,
         constant   int32_t & ofs1,
         constant   int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
         (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
         (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
 
+    device T * pdst = (device T *) (dst);
+
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = 0.0f;
+        pdst[offset_dst] = 0.0f;
     } else {
         const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
-        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
     }
 }
 
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,