git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add f16 support
author    Georgi Gerganov <redacted>
Tue, 6 Jun 2023 17:16:57 +0000 (20:16 +0300)
committer Georgi Gerganov <redacted>
Tue, 6 Jun 2023 17:21:56 +0000 (20:21 +0300)
ggml-metal.m
ggml-metal.metal
llama.cpp

index d721ac688c707bfcfe8ce1c12db09cd06973ac6e..0953af6a497c9a6a379692a3f09192f4680a50c5 100644 (file)
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -47,10 +47,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(rms_norm);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(rms_norm);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
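
The GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL pair is what makes registering the new get_rows_f16 kernel a one-line change in each hunk above: the macros presumably token-paste the kernel name into a pipeline-state field and into a lookup of the Metal function "kernel_<name>". A plain-C schematic of that pattern only (pipeline_t and load_pipeline are hypothetical stand-ins for Metal's pipeline-state type and library lookup):

    #include <stdio.h>

    typedef void * pipeline_t; /* stand-in for id<MTLComputePipelineState> */

    /* hypothetical stand-in for fetching "kernel_<name>" from the compiled
     * Metal library and building a compute pipeline state for it */
    static pipeline_t load_pipeline(const char * name) {
        printf("registering %s\n", name);
        return NULL;
    }

    #define GGML_METAL_DECL_KERNEL(name) pipeline_t pipeline_##name
    #define GGML_METAL_ADD_KERNEL(name) \
        ctx.pipeline_##name = load_pipeline("kernel_" #name)

    struct metal_context_sketch {
        GGML_METAL_DECL_KERNEL(get_rows_f16);
        GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
    };

    int main(void) {
        struct metal_context_sketch ctx;
        GGML_METAL_ADD_KERNEL(get_rows_f16);    /* -> "kernel_get_rows_f16"    */
        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); /* -> "kernel_mul_mat_f16_f32" */
        return 0;
    }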
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
 
                         // use custom matrix x vector kernel
                         switch (src0t) {
+                            case GGML_TYPE_F16:
+                                {
+                                    GGML_ASSERT(ne02 == ne12);
+
+                                    nth0 = 64;
+                                    nth1 = 1;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                } break;
                             case GGML_TYPE_Q4_0:
                                 {
                                     GGML_ASSERT(ne02 == 1);
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
                                     nth1 = 4;
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                 } break;
-                            case GGML_TYPE_F16:
-                                {
-                                    GGML_ASSERT(ne02 == ne12);
-
-                                    nth0 = 32;
-                                    nth1 = 1;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                                } break;
                             default: GGML_ASSERT(false && "not implemented");
                         };
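
Two things change in this switch, as the two hunks above show: the GGML_TYPE_F16 case moves ahead of GGML_TYPE_Q4_0, and its threadgroup width nth0 goes from 32 to 64. For orientation, the operation being dispatched is a matrix x vector product with f16 weights and f32 activations; below is a minimal CPU reference of the semantics only (not of the kernel's internals, which split each row's dot product across nth0 threads), in plain C, assuming a compiler with _Float16 support such as recent GCC/Clang:

    #include <stdint.h>

    /* Reference semantics: dst[r] = dot(row r of src0 (f16), src1 (f32)).
     * The Metal kernel parallelizes each row; this sketch is sequential. */
    void mul_mat_vec_f16_f32_ref(const _Float16 * src0, const float * src1,
                                 float * dst, int64_t nrows, int64_t ne00) {
        for (int64_t r = 0; r < nrows; ++r) {
            float sum = 0.0f;
            for (int64_t j = 0; j < ne00; ++j) {
                sum += (float) src0[r*ne00 + j] * src1[j];
            }
            dst[r] = sum;
        }
    }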
 
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
                     }
 
                     switch (src0->type) {
+                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         default: GGML_ASSERT(false && "not implemented");
                     }
index 4bedc8ea45e573beb8790f5688ee793124cb8bc0..a359bebe2d79987a64d91363bdd4ae88f6529c88 100644 (file)
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int j = 0; j < ne00; j++) {
+        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+    }
+}
+
 kernel void kernel_get_rows_q4_0(
         device const  void * src0,
         device const   int * src1,
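
kernel_get_rows_f16 above gathers whole rows: for each index i in src1, row src1[i] of the f16 matrix src0 is converted to f32 and written into row i of dst, which is how embedding lookups land on the GPU. The same gather as a CPU sketch in plain C (hypothetical helper; strides are given in elements here, whereas the kernel receives the byte strides nb01/nb1; _Float16 support assumed):

    #include <stdint.h>

    /* Gather rows: dst row i <- f32 copy of src0 row src1[i]. */
    void get_rows_f16_ref(const _Float16 * src0, const int32_t * src1,
                          float * dst, int64_t nrows, int64_t ne00) {
        for (int64_t i = 0; i < nrows; ++i) {
            const int64_t r = src1[i];
            for (int64_t j = 0; j < ne00; ++j) {
                dst[i*ne00 + j] = (float) src0[r*ne00 + j];
            }
        }
    }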
index 70341d04f91b0a255014960c173e63d50da63b28..73f686003dc2bf5ba6413b4aed73fad36d618f46 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -961,7 +961,6 @@ static void llama_model_load_internal(
     model.hparams = ml->file_loaders.at(0)->hparams;
     llama_file_version file_version = ml->file_loaders.at(0)->file_version;
     auto & hparams = model.hparams;
-    uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
 
     {
         switch (hparams.n_layer) {
@@ -975,6 +974,8 @@ static void llama_model_load_internal(
         hparams.n_ctx = n_ctx;
     }
 
+    const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
     {
         fprintf(stderr, "%s: format     = %s\n",  __func__, llama_file_version_name(file_version));
         fprintf(stderr, "%s: n_vocab    = %u\n",  __func__, hparams.n_vocab);
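
The relocated n_ff expression rounds 2*(4*n_embd)/3 up to the next multiple of n_mult; moving it below the block above means it is now computed after the hparams adjustments (such as hparams.n_ctx = n_ctx) rather than before them. A quick sanity check of the rounding in plain C, using the published LLaMA-7B values (n_embd = 4096, n_mult = 256; these numbers are not part of this diff):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t n_embd = 4096; /* LLaMA-7B, for illustration */
        const uint32_t n_mult = 256;

        /* round 2*(4*n_embd)/3 up to a multiple of n_mult */
        const uint32_t n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult;

        printf("n_ff = %u\n", n_ff); /* prints 11008, the expected 7B FFN width */
        return 0;
    }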