]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add kernel_get_rows_i32
authorGeorgi Gerganov <redacted>
Wed, 3 Jan 2024 09:35:46 +0000 (11:35 +0200)
committerGeorgi Gerganov <redacted>
Wed, 3 Jan 2024 12:43:51 +0000 (14:43 +0200)
ggml-ci

ggml-metal.m
ggml-metal.metal

index 7a369b55e36282e2de083db6b860b9f38d4e059d..7aa92c14c9cdcc326a40c72afbfb1ab58ad36367 100644 (file)
@@ -87,6 +87,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DECL_KERNEL(get_rows_i32);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
@@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
         GGML_METAL_ADD_KERNEL(get_rows_q5_K);
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
+        GGML_METAL_ADD_KERNEL(get_rows_i32);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(group_norm);
         GGML_METAL_ADD_KERNEL(norm);
@@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q4_K);
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(get_rows_i32);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
@@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
                                 case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
                                 case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
                                 case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
+                                case GGML_TYPE_I32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_i32];  break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
index 9aa7b502a9ea094e386be638cc6827d9188600a6..a7d3f9efa5787a2526fd5af424bcf2059fe731e3 100644 (file)
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
     }
 }
 
+kernel void kernel_get_rows_i32(
+        device const  void * src0,
+        device const  char * src1,
+        device     int32_t * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32