GGML_METAL_DECL_KERNEL(get_rows_q4_K);
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DECL_KERNEL(get_rows_i32);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
+ GGML_METAL_ADD_KERNEL(get_rows_i32);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DEL_KERNEL(get_rows_i32);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
+ case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
default: GGML_ASSERT(false && "not implemented");
}
}
}
+kernel void kernel_get_rows_i32(
+ device const void * src0,
+ device const char * src1,
+ device int32_t * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32