From: bssrdf Date: Tue, 20 Feb 2024 19:17:09 +0000 (-0500) Subject: ggml : fix conv_2d batch mode (ggml/737) X-Git-Tag: upstream/1.7.4~977 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d352dbd163019d12b2234e60b7d161fe4192aee9;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml : fix conv_2d batch mode (ggml/737) Co-authored-by: bssrdf --- diff --git a/ggml.c b/ggml.c index 4ee2c5e1..86e3bc2e 100644 --- a/ggml.c +++ b/ggml.c @@ -5629,7 +5629,9 @@ struct ggml_tensor * ggml_conv_2d( ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] + return result; }