From: Georgi Gerganov
Date: Fri, 2 Jun 2023 12:46:59 +0000 (+0300)
Subject: ggml : add ggml_conv_2d_sk_p0(), ggml_win_part(), ggml_win_unpart()
X-Git-Tag: upstream/0.0.1642~1418
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=758471b22630cc037244dbe1961a87097988aa75;p=pkg%2Fggml%2Fsources%2Fggml

ggml : add ggml_conv_2d_sk_p0(), ggml_win_part(), ggml_win_unpart()
---

diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index e0c50c29..dce092a1 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -318,9 +318,12 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D_S1_PH,
         GGML_OP_CONV_1D_S2_PH,
+        GGML_OP_CONV_2D_SK_P0,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
 
+        GGML_OP_WIN_PART,
+        GGML_OP_WIN_UNPART,
+
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
@@ -963,6 +966,19 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -978,6 +994,26 @@ extern "C" {
             struct ggml_tensor  * c0,
             struct ggml_tensor  * c1);
 
+    // partition into non-overlapping windows with padding if needed
+    // example:
+    // a:   768   64   64    1
+    // w:    14
+    // res: 768   14   14   25
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_part(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w);
+
+    // reverse of ggml_win_part
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_unpart(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w0,
+            int                   h0,
+            int                   w);
+
     // Mapping operations
     typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
     typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
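The three ops above are the building blocks of SAM's windowed ViT encoder ("used in sam"). A minimal sketch of the intended shape flow, using the numbers from the header comments; it is illustrative only and not part of this commit: ctx, the F16 kernel and the F32 img are assumed to be set up by the caller, and the permute to channel-first layout between the conv output and ggml_win_part is elided (e.g. ggml_permute + ggml_cont):

    // patch embedding: stride == kernel size (16), zero padding
    // kernel: 16 x 16 x 3 x 768 (F16), img: 1024 x 1024 x 3 x 1 (F32)
    struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx, kernel, img); // -> 64 x 64 x 768 x 1

    // ... permute/reshape to 768 x 64 x 64 x 1 (channel-first) ...

    // partition into 14x14 windows; 64 is not a multiple of 14,
    // so the bottom/right borders are zero-padded
    struct ggml_tensor * win = ggml_win_part(ctx, cur, 14);          // -> 768 x 14 x 14 x 25

    // ... per-window attention ...

    // merge the windows back and drop the padding
    cur = ggml_win_unpart(ctx, win, 64, 64, 14);                     // -> 768 x 64 x 64 x 1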
diff --git a/src/ggml.c b/src/ggml.c
index bff058cc..18a2007b 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -3547,16 +3547,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CLAMP",
     "CONV_1D_S1_PH",
     "CONV_1D_S2_PH",
+    "CONV_2D_SK_P0",
 
     "FLASH_ATTN",
     "FLASH_FF",
 
+    "WIN_PART",
+    "WIN_UNPART",
+
     "MAP_UNARY",
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
-
+static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3609,15 +3611,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "clamp(x)",
     "conv_1d_s1_ph(x)",
     "conv_1d_s2_ph(x)",
+    "conv_2d_sk_p0(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
 
+    "win_part(x)",
+    "win_unpart(x)",
+
     "f(x)",
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -6447,6 +6452,34 @@ struct ggml_tensor * ggml_conv_1d_s2_ph(
     return result;
 }
 
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(b->ne[0] % a->ne[0] == 0);
+    GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_CONV_2D_SK_P0;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -6511,6 +6544,90 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_win_part
+
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w) {
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(a->type  == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((int32_t *) b->data)[0] = npx;
+    ((int32_t *) b->data)[1] = npy;
+    ((int32_t *) b->data)[2] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_PART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   w) {
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ((int32_t *) b->data)[0] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_UNPART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
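The padding arithmetic in ggml_win_part is easiest to check against the header example (a = 768 x 64 x 64 x 1, w = 14):

    px  = (14 - 64 % 14) % 14 = (14 - 8) % 14 = 6   // zero padding per side
    npx = (6 + 64) / 14       = 5                   // windows per side
    np  = npx * npy           = 5 * 5 = 25          // matches "res: 768 14 14 25"

npx, npy and w are stashed in the I32 opt[0] tensor (under ggml_scratch_save/load so they are not allocated from the scratch buffer) because np = npx*npy in the output shape alone would not determine the window grid.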
@@ -12101,6 +12218,143 @@ static void ggml_compute_forward_conv_1d_s2_ph(
     }
 }
 
+// ggml_compute_forward_conv_2d_sk_p0
+
+static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    //const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    //const int ne3 = dst->ne[3];
+    //const int ne  = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    //const int nb01 = src0->nb[1];
+    //const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    //const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0 = dst->nb[0];
+    //const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    //const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk0 = ne00;
+    const int nk1 = ne01;
+
+    // size of the convolution row - the kernel size unrolled across all channels
+    // round-up so it is more suitable for SIMD
+    const int ew0 = ggml_up32(nk0*nk1*ne02);
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
+
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                    GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+    for (int i2 = ip0; i2 < ip1; i2++) {
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+
+        for (int i1 = 0; i1 < ne1; ++i1) {
+            for (int i0 = 0; i0 < ne0; ++i0) {
+                ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
+                        (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
+                        (ggml_fp16_t *)                wdata + (i1*ne0 + i0)*ew0);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_sk_p0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
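The forward pass above is im2col in disguise: the INIT task unrolls src1 so that every output pixel owns one contiguous row of length ew0 (the kernel footprint across all input channels, rounded up to a multiple of 32 so ggml_vec_dot_f16 stays SIMD-friendly), and the main task then reduces each row against one F16 kernel slice, with threads split over the ne2 output channels. A sketch of the layout, with the SAM shapes giving ew0 = ggml_up32(16*16*3) = 768 and 64*64 rows:

    // written by the INIT pass (one row per output pixel):
    //   wdata[(i1*ne0 + i0)*ew0 + i12*nk0*nk1 + ik1*nk0 + ik0]
    //     = fp16(src1[i12][i1*nk1 + ik1][i0*nk0 + ik0])
    //
    // hot loop: a single dot product per output element
    //   dst[i2][i1][i0] = dot_f16(ew0, kernel_slice(i2), wdata_row(i1, i0))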
@@ -12787,6 +13041,146 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
+    UNUSED(ne00);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
+    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
+    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
+                        const int64_t j = i02*ne01*ne00  + i01*ne00   + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+
+    const int32_t w = ((const int32_t *)(opt0->data))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
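With the channel-first layout these kernels operate on ([ne0 = C, ne1 = W, ne2 = H]), the two loops implement inverse index maps; a compact restatement of the arithmetic above:

    // win_part:   dst[py*npx + px][i2][i1][i0] = src[py*w + i2][px*w + i1][i0]
    //             (zero-filled where py*w + i2 >= H or px*w + i1 >= W)
    // win_unpart: dst[i2][i1][i0] = src[(i2/w)*npx + i1/w][i2%w][i1%w][i0]

so ggml_win_unpart(ggml_win_part(x, w), W, H, w) reproduces x, discarding whatever was written into the padded border in between.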
@@ -13070,17 +13464,29 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_CONV_2D_SK_P0:
+            {
+                ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
-                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
                 GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
+                const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
             } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13767,6 +14173,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONV_2D_SK_P0:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 GGML_ASSERT(false); // not supported
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
             {
@@ -14299,6 +14711,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         GGML_ASSERT(false);
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_2D_SK_P0:
+                {
+                    node->n_tasks = n_threads;
+
+                    GGML_ASSERT(node->src1->ne[3] == 1);
+
+                    const int64_t ne00 = node->src0->ne[0]; // W
+                    const int64_t ne01 = node->src0->ne[1]; // H
+                    const int64_t ne02 = node->src0->ne[2]; // C
+                    const int64_t ne03 = node->src0->ne[3]; // N
+
+                    const int64_t ne10 = node->src1->ne[0]; // W
+                    const int64_t ne11 = node->src1->ne[1]; // H
+                    const int64_t ne12 = node->src1->ne[2]; // C
+
+                    const int64_t nk = ne00*ne01;
+
+                    UNUSED(ne02);
+                    UNUSED(ne03);
+                    UNUSED(nk);
+
+                    size_t cur = 0;
+
+                    if (node->src0->type == GGML_TYPE_F16 &&
+                        node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                               node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*      (ne10*ne11*ne12);
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_FLASH_ATTN:
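The reservation here is one half-precision copy of src1: since ne0*nk0 == ne10 and ne1*nk1 == ne11, the ne1*ne0 unrolled rows of ew0 halves written during INIT come to exactly ne10*ne11*ne12 elements whenever nk0*nk1*ne12 is a multiple of 32 (16*16*3 = 768 for the SAM shapes, i.e. 6 MiB at F16). Note that ggml_up32() only rounds ew0 upward, so for kernel footprints that are not a multiple of 32 the rows are slightly larger than this estimate; the "TODO: fix this memset (wsize is overestimated)" in the INIT task points at the related wsize bookkeeping.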
@@ -14339,6 +14786,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                     work_size = MAX(work_size, cur);
                 } break;
+            case GGML_OP_WIN_PART:
+            case GGML_OP_WIN_UNPART:
             case GGML_OP_MAP_UNARY:
             case GGML_OP_MAP_BINARY:
                 {
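Routing GGML_OP_WIN_PART and GGML_OP_WIN_UNPART through the same branch as the MAP ops is consistent with their forward kernels above: both are plain single-threaded copies (note their "TODO: optimize / multi-thread") and use no scratch memory, so no per-node work size needs to be reserved for them.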