]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix batch for ggml_conv_2d (#528)
authorskirodev <redacted>
Thu, 28 Sep 2023 21:10:45 +0000 (05:10 +0800)
committerGitHub <redacted>
Thu, 28 Sep 2023 21:10:45 +0000 (00:10 +0300)
src/ggml.c

index 5c794089484f0c7b4d5cf7be07aeb92195b169bd..2a193849949c8164c814ae74fb54971ce0fde4c5 100644 (file)
@@ -13935,20 +13935,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i12 = 0; i12 < ne12; i12++) {
-                const float * const src = (float *)((char *) src1->data + i12*nb12);
-                ggml_fp16_t * dst_data = wdata;
-
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        for (int ik1 = 0; ik1 < nk1; ik1++) {
-                            for (int ik0 = 0; ik0 < nk0; ik0++) {
-                                const int idx0 = i0*s0 + ik0*d0 - p0;
-                                const int idx1 = i1*s1 + ik1*d1 - p1;
-
-                                if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
-                                    dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
-                                        GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+            for (int i13 = 0; i13 < ne13; i13++) {
+                for (int i12 = 0; i12 < ne12; i12++) {
+                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
+                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
+
+                    for (int i1 = 0; i1 < ne1; i1++) {
+                        for (int i0 = 0; i0 < ne0; i0++) {
+                            for (int ik1 = 0; ik1 < nk1; ik1++) {
+                                for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                    const int idx0 = i0*s0 + ik0*d0 - p0;
+                                    const int idx1 = i1*s1 + ik1*d1 - p1;
+
+                                    if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
+                                        dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                            GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+                                    }
                                 }
                             }
                         }