From: bssrdf Date: Tue, 20 Feb 2024 19:17:09 +0000 (-0500) Subject: ggml : fix conv_2d batch mode (#737) X-Git-Tag: upstream/0.0.1642~932 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=25cc1d5d8329f55c8fc5202259046c73cc298e4e;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix conv_2d batch mode (#737) Co-authored-by: bssrdf --- diff --git a/src/ggml.c b/src/ggml.c index 4ee2c5e1..86e3bc2e 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5629,7 +5629,9 @@ struct ggml_tensor * ggml_conv_2d( ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] + return result; }