]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix conv_2d batch mode (ggml/737)
authorbssrdf <redacted>
Tue, 20 Feb 2024 19:17:09 +0000 (14:17 -0500)
committerGeorgi Gerganov <redacted>
Thu, 22 Feb 2024 13:12:32 +0000 (15:12 +0200)
Co-authored-by: bssrdf <redacted>
ggml.c

diff --git a/ggml.c b/ggml.c
index 4ee2c5e11a0025a1509a10cef75fa0375d4cd252..86e3bc2ea3c2568f6b107dcac90a06788798ed4d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -5629,7 +5629,9 @@ struct ggml_tensor * ggml_conv_2d(
                 ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
                 ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
 
-    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
+    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
+
 
     return result;
 }