]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : copy kernels for quant to F32/F16 conversions (llama/12017)
authorGian-Carlo Pascutto <redacted>
Tue, 25 Feb 2025 09:27:58 +0000 (10:27 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
metal: use dequantize_q templates

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 087e7f58149f6a537cedca454462c46c805ca2b0..c550142a7d07c635c3e048973826f3899a2c1fb2 100644 (file)
@@ -407,6 +407,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_F16:
+                                return true;
+                            default:
+                                return false;
+                        }
                     default:
                         return false;
                 };
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;
 
                 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                             switch (dstt) {
                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
                             };
                         } break;
                     default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
index 83e7ac9f411ef3f0c74fcb0aade6d691641d3a77..d092a16906155085f2aa594efcdb0d61218edcb9 100644 (file)
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,