]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix 1D im2col, add tests (ggml/993)
authorJohannes Gäßler <redacted>
Fri, 18 Oct 2024 07:24:44 +0000 (09:24 +0200)
committerGeorgi Gerganov <redacted>
Wed, 23 Oct 2024 13:50:02 +0000 (16:50 +0300)
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/im2col.cu
tests/test-backend-ops.cpp

index 1338bd45836bb58e7841b6184ce52e6333c34241..fa280b529bcb8d9546358cb00a024b412ceafad8 100644 (file)
@@ -3141,7 +3141,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
-            return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
index 16463ab0fb683cf8e24095bc9eb3209e24e7d933..86a54e42bb7e64d4d2b574b9b51e7ba4b20c9ee5 100644 (file)
@@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[3];
-    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch        = src1->ne[is_2D ? 3 : 2];
+    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
index e087f7ba5be1667a591fdd18e0836dbc0270e8c8..7e769a91a1944c861c50faa069ccd420fd11982e 100644 (file)
@@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
-    // test cases for 1D im2col
+    // im2col 1D
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    for (int s0 : {1, 3}) {
+        for (int p0 : {0, 3}) {
+            for (int d0 : {1, 3}) {
+                test_cases.emplace_back(new test_im2col(
+                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
+                    s0, 0, p0, 0, d0, 0, false));
+            }
+        }
+    }
+
+    // im2col 2D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+    for (int s0 : {1, 3}) {
+        for (int s1 : {1, 3}) {
+            for (int p0 : {0, 3}) {
+                for (int p1 : {0, 3}) {
+                    for (int d0 : {1, 3}) {
+                        for (int d1 : {1, 3}) {
+                            test_cases.emplace_back(new test_im2col(
+                                GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
+                                s0, s1, p0, p1, d0, d1, true));
+                        }
+                    }
+                }
+            }
+        }
+    }
 
-    // test cases for 2D im2col
+    // extra tests for im2col 2D
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));