From: Guillaume Wenzek
Date: Fri, 29 Dec 2023 17:07:03 +0000 (+0100)
Subject: ggml : extend ggml_get_rows, ggml_repeat, ggml_concat (#639)
X-Git-Tag: upstream/0.0.1642~1140
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=fca1caafea7de9fbd7efc733b9818f9cf2da3050;p=pkg%2Fggml%2Fsources%2Fggml

ggml : extend ggml_get_rows, ggml_repeat, ggml_concat (#639)

* add more int ops

* ggml_compute_forward_dup_bytes

* add tests

* PR comments

* tests : minor indentations

---------

Co-authored-by: Georgi Gerganov
---

diff --git a/src/ggml.c b/src/ggml.c
index a9e1ea9b..4d8891c3 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -4766,8 +4766,11 @@ struct ggml_tensor * ggml_get_rows(
     }
 
     // TODO: implement non F32 return
-    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+    enum ggml_type type = GGML_TYPE_F32;
+    if (a->type == GGML_TYPE_I32) {
+        type = a->type;
+    }
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op   = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
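The hunk above changes only the result dtype of ggml_get_rows: when the source matrix is I32, the result stays I32 instead of being up-cast to F32. This is safe because the existing F32 row-copy kernel is reused and both types are 4 bytes wide. A minimal sketch of the observable effect (not part of the commit; it assumes the single-threaded ggml_graph_compute_with_ctx helper from this era of the library):

#include "ggml/ggml.h"

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a 3x4 I32 table: row r holds {100 + 3r, 101 + 3r, 102 + 3r}
    struct ggml_tensor * table = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 3, 4);
    for (int i = 0; i < 12; ++i) {
        ggml_set_i32_1d(table, i, 100 + i);
    }

    // select rows 2 and 0
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
    ggml_set_i32_1d(ids, 0, 2);
    ggml_set_i32_1d(ids, 1, 0);

    struct ggml_tensor * rows = ggml_get_rows(ctx, table, ids);
    GGML_ASSERT(rows->type == GGML_TYPE_I32); // before this commit: always GGML_TYPE_F32

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, rows);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    GGML_ASSERT(ggml_get_i32_1d(rows, 0) == 106); // first element of row 2

    ggml_free(ctx);
    return 0;
}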
@@ -6938,14 +6941,165 @@ static void ggml_compute_forward_dup_f32(
     }
 }
 
-static void ggml_compute_forward_dup(
+// A simplified version of ggml_compute_forward_dup that skips the float up-conversion and just does a plain byte-wise memcpy.
+static void ggml_compute_forward_dup_bytes(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     struct ggml_tensor * dst) {
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
         ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const size_t type_size = ggml_type_size(src0->type);
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ne00 * type_size;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contiguous on first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            i10 += ne00 * ir0;
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++i10 == ne0) {
+                        i10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_dup(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    if (src0->type == dst->type) {
+        ggml_compute_forward_dup_bytes(params, src0, dst);
+        return;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -8404,10 +8558,12 @@ static void ggml_compute_forward_repeat(
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
+        case GGML_TYPE_I16:
             {
                 ggml_compute_forward_repeat_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
@@ -8550,6 +8706,7 @@ static void ggml_compute_forward_concat(
         struct ggml_tensor* dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_concat_f32(params, src0, src1, dst);
             } break;
@@ -10674,6 +10831,7 @@ static void ggml_compute_forward_get_rows(
                 ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 63b7fac0..754d6376 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -312,6 +312,14 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 
+#
+# test-dup
+
+set(TEST_TARGET test-dup)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
 #
 # test-rel-pos
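ggml_compute_forward_dup_bytes above parallelizes over rows of src0 with a ceil-divide split: each of the nth threads takes dr = (nr + nth - 1) / nth rows, and thread ith handles the half-open range [ir0, ir1). A standalone illustration of that arithmetic (not from the patch): with nr = 10 and nth = 4 it prints the ranges [0,3), [3,6), [6,9), [9,10), so the last thread absorbs the remainder.

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows (ne01)
    const int nth = 4;  // number of threads

    // rows per thread, rounded up so every row is covered
    const int dr = (nr + nth - 1) / nth;

    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr * ith;          // first row of this thread
        const int ir1 = MIN(ir0 + dr, nr); // one past its last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}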
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b115299c..0c749813 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -58,6 +58,9 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
         int64_t hist[16];
         ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist);
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
+    } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
+        // Note: this reinterprets the random float bits as integers, so the values are arbitrary.
+        ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
     } else {
         GGML_ASSERT(false);
     }
@@ -87,8 +90,13 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
                         tv.push_back(*(float *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I32) {
                         tv.push_back((float)*(int32_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I16) {
+                        tv.push_back((float)*(int16_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I8) {
+                        tv.push_back((float)*(int8_t *) &buf[i]);
                     } else if (quantized) {
-                        tt.to_float(&buf[i], vq.data(), bs);
+                        std::vector<float> vq(ggml_blck_size(t->type));
+                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
                         tv.insert(tv.end(), vq.begin(), vq.end());
                     } else {
                         GGML_ASSERT(false);
@@ -661,17 +669,26 @@ struct test_repeat : public test_case {
 struct test_dup : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> permute;
+    bool _use_permute;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
     }
 
     test_dup(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 1})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 10, 10, 1},
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        if (_use_permute) {
+            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+        }
         ggml_tensor * out = ggml_dup(ctx, src);
         return out;
     }
@@ -1450,14 +1467,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
             }
         }
     }
+    for (int b : {1, 7}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v));
+        }
+    }
 
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 10, 10, 10}, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 10, 10, 10}, {1, 1, 1, 2}));
 
-    test_cases.emplace_back(new test_dup());
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
 
     for (ggml_type type : all_types) {
         test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
@@ -1565,7 +1594,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     test_cases.emplace_back(new test_alibi());
     test_cases.emplace_back(new test_im2col());
-    test_cases.emplace_back(new test_concat());
+    test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
+    test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
 
     for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
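About the "arbitrary values" note in init_tensor_uniform above: the buffer of uniform floats is copied into integer tensors byte for byte, so every I8/I16/I32 element is a reinterpreted float bit pattern rather than a converted value. The values are arbitrary but deterministic, which is sufficient for ops that must copy bytes exactly. A short demonstration of the reinterpretation (illustration only, not from the patch):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    float f = 0.5f;
    int32_t i;
    memcpy(&i, &f, sizeof(i));    // same 4 bytes, different interpretation
    printf("%.1f -> %d\n", f, i); // prints: 0.5 -> 1056964608 (0x3F000000)
    return 0;
}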
diff --git a/tests/test-dup.c b/tests/test-dup.c
new file mode 100644
index 00000000..afc887b7
--- /dev/null
+++ b/tests/test-dup.c
@@ -0,0 +1,110 @@
+#include "ggml/ggml.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+void arange(struct ggml_tensor* tensor) {
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+    for (int i = 0; i < ggml_nelements(tensor); ++i) {
+        ggml_set_i32_1d(tensor, i, i);
+    }
+}
+
+void dup_to(struct ggml_tensor* src, struct ggml_tensor* dst) {
+    GGML_ASSERT(dst->op == GGML_OP_VIEW);
+    GGML_ASSERT(ggml_nelements(src) == ggml_nelements(dst));
+    dst->op = GGML_OP_DUP;
+    dst->src[0] = src;
+}
+
+bool can_dup(enum ggml_type src_type, enum ggml_type dst_type) {
+    if (src_type == dst_type) return true;
+    if (src_type == GGML_TYPE_F32 && ggml_internal_get_type_traits(dst_type).from_float) return true;
+    if (dst_type == GGML_TYPE_F32 && ggml_internal_get_type_traits(src_type).to_float) return true;
+
+    return false;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+        .no_alloc   = false,
+    };
+
+    enum ggml_type type[4] = {GGML_TYPE_I16, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_F32};
+    for (int i = 0; i < 4; ++i) {
+        enum ggml_type src_type = type[i];
+        for (int j = 0; j < 4; ++j) {
+            enum ggml_type dst_type = type[j];
+            if (!can_dup(src_type, dst_type)) continue;
+            printf("Testing dup on %s -> %s copy\n", ggml_type_name(src_type), ggml_type_name(dst_type));
+
+            struct ggml_context * ctx = ggml_init(params);
+
+            struct ggml_tensor * src = ggml_new_tensor_2d(ctx, src_type, 10, 11);
+            arange(src);
+            struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, dst_type, 10, 11);
+            ggml_set_i32(dst, 0);
+
+            // row 2: [20, 21, ..., 29]
+            struct ggml_tensor * src_cont = ggml_view_1d(ctx, src, 10, src->nb[1] * 2);
+
+            // column 3: [03, 13, ..., 93]
+            struct ggml_tensor * src_stride = ggml_view_2d(ctx, src, 1, 10, src->nb[1], src->nb[0] * 3);
+
+            struct ggml_tensor * dst_cont_1 = ggml_view_1d(ctx, dst, 10, dst->nb[1] * 5); // row 5
+            struct ggml_tensor * dst_cont_2 = ggml_view_1d(ctx, dst, 10, dst->nb[1] * 6); // row 6
+
+            struct ggml_tensor * dst_stride_1 = ggml_view_2d(ctx, dst, 1, 10, dst->nb[1], dst->nb[0] * 7); // column 7
+            struct ggml_tensor * dst_stride_2 = ggml_view_2d(ctx, dst, 1, 10, dst->nb[1], dst->nb[0] * 8); // column 8
+
+            struct ggml_cgraph * gf = ggml_new_graph(ctx);
+
+            dup_to(src_cont, dst_cont_1);
+            dup_to(src_stride, dst_cont_2);
+            dup_to(src_cont, dst_stride_1);
+            dup_to(src_stride, dst_stride_2);
+
+            ggml_build_forward_expand(gf, dst_cont_1);
+            ggml_build_forward_expand(gf, dst_cont_2);
+            ggml_build_forward_expand(gf, dst_stride_1);
+            ggml_build_forward_expand(gf, dst_stride_2);
+
+            ggml_graph_compute_with_ctx(ctx, gf, 1);
+
+            // src_cont -> dst_cont_1
+            GGML_ASSERT(ggml_get_i32_1d(dst, 49) == 0);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 50) == 20);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 51) == 21);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 52) == 22);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 59) == 29);
+
+            // src_stride -> dst_cont_2
+            GGML_ASSERT(ggml_get_i32_1d(dst, 60) == 3);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 61) == 13);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 62) == 23);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 69) == 93);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 70) == 0);
+
+            // src_cont -> dst_stride_1
+            GGML_ASSERT(ggml_get_i32_1d(dst, 6) == 0);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 7) == 20);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 17) == 21);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 27) == 22);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 97) == 29);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 107) == 0);
+
+            // src_stride -> dst_stride_2
+            GGML_ASSERT(ggml_get_i32_1d(dst, 8) == 03);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 18) == 13);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 28) == 23);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 98) == 93);
+            GGML_ASSERT(ggml_get_i32_1d(dst, 108) == 0);
+
+            ggml_free(ctx);
+        }
+    }
+
+    return 0;
+}
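Two notes on the test above. The dup_to helper rewrites an existing GGML_OP_VIEW node into a GGML_OP_DUP in place, so the computed copy is written into the strided region of dst that the view addresses; that is what exercises the non-contiguous branches of ggml_compute_forward_dup_bytes. For the ops extended earlier in the patch, integer tensors ride the existing F32/F16 kernels, which works because the element sizes match and those kernels move elements verbatim. A sketch of the newly enabled path from user code (not part of the commit; shapes and values are illustrative):

#include "ggml/ggml.h"

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
    for (int i = 0; i < 4; ++i) {
        ggml_set_i32_1d(a, i, i); // a = [0, 1, 2, 3]
    }

    // I32 repeat now dispatches to the F32 kernel (both are 4-byte types)
    struct ggml_tensor * shape = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 8);
    struct ggml_tensor * rep   = ggml_repeat(ctx, a, shape); // [0,1,2,3,0,1,2,3]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, rep);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    GGML_ASSERT(rep->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_get_i32_1d(rep, 5) == 1);

    ggml_free(ctx);
    return 0;
}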